Skip to content

Commit

Permalink
Implement Div + Div to common denominators
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Jul 6, 2023
1 parent e557f1b commit 07068af
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 72 deletions.
33 changes: 14 additions & 19 deletions src/fintamath/expressions/binary/DivExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,12 @@ bool DivExpression::isNeg(const ArgumentPtr &expr) {

ArgumentPtr DivExpression::sumSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
if (auto res = sumMulSimplify(lhs, rhs)) {
simplifyChild(res);
return res;
}

if (auto [lhsRes, rhsRes] = mulSumSimplify(lhs, rhs); lhsRes) {
ArgumentPtr res = makeExpr(Add(), lhsRes, rhsRes);
simplifyChild(res);
postSimplifyChild(res);
return res;
}

Expand Down Expand Up @@ -265,14 +264,15 @@ ArgumentPtr DivExpression::sumSumSimplify(const ArgumentPtr &lhs, const Argument
remainderVect.emplace_back(child);
}
}

if (resultVect.empty()) {
return {};
}

resultVect.emplace_back(makeExpr(Div(), makeExpr(Add(), remainderVect), rhs));

ArgumentPtr result = makeExpr(Add(), resultVect);
simplifyChild(result);
postSimplifyChild(result);
return result;
}

Expand Down Expand Up @@ -313,7 +313,9 @@ ArgumentPtr DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const Argument
result.emplace_back(divExpr);
}

return makeExpr(Add(), result);
ArgumentPtr res = makeExpr(Add(), result);
postSimplifyChild(res);
return res;
}

std::pair<ArgumentPtr, ArgumentPtr> DivExpression::mulSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
Expand Down Expand Up @@ -410,25 +412,20 @@ ArgumentPtr DivExpression::polynomSimplify(const IFunction & /*func*/, const Arg
if (const auto &lhsExpr = cast<IExpression>(lhs)) {
if (is<Add>(lhsExpr->getFunction())) {
result = numeratorSumSimplify(lhsExpr->getChildren(), rhs);
postSimplifyChild(result);
}

if (is<Mul>(lhsExpr->getFunction())) {
else if (is<Mul>(lhsExpr->getFunction())) {
result = numeratorMulSimplify(lhsExpr->getChildren(), rhs);
postSimplifyChild(result);
}
}

if (result) {
simplifyChild(result);
return result;
}

if (const auto &rhsExpr = cast<IExpression>(rhs)) {
else if (const auto &rhsExpr = cast<IExpression>(rhs)) {
if (is<Add>(rhsExpr->getFunction())) {
result = denominatorSumSimplify(lhs, rhs, rhsExpr->getChildren());
}
}

simplifyChild(result);
postSimplifyChild(result);
return result;
}

Expand Down Expand Up @@ -537,17 +534,15 @@ ArgumentPtr DivExpression::denominatorSumSimplify(const ArgumentPtr &lhs, const

ArgumentsPtrVector numeratorChildren = multiplicator;
numeratorChildren.emplace_back(lhs);

ArgumentPtr numerator = makeExpr(Mul(), numeratorChildren);
simplifyChild(numerator);

ArgumentsPtrVector denominatorChildren = multiplicator;
denominatorChildren.emplace_back(rhs);

ArgumentPtr denominator = makeExpr(Mul(), denominatorChildren);
simplifyChild(denominator);

return makeExpr(Div(), numerator, denominator);
ArgumentPtr res = makeExpr(Div(), numerator, denominator);
postSimplifyChild(res);
return res;
}

ArgumentPtr DivExpression::denominatorMulSimplify(const ArgumentsPtrVector &rhsChildren) {
Expand Down
37 changes: 23 additions & 14 deletions src/fintamath/expressions/polynomial/AddExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,7 @@ AddExpression::SimplifyFunctionsVector AddExpression::getFunctionsForSimplify()
AddExpression::SimplifyFunctionsVector AddExpression::getFunctionsForPreSimplify() const {
static const AddExpression::SimplifyFunctionsVector simplifyFunctions = {
&AddExpression::simplifyNegations, //
};
return simplifyFunctions;
}

AddExpression::SimplifyFunctionsVector AddExpression::getFunctionsForPostSimplify() const {
static const AddExpression::SimplifyFunctionsVector simplifyFunctions = {
&AddExpression::sumDivisions, //
&AddExpression::sumDivisions, //
};
return simplifyFunctions;
}
Expand Down Expand Up @@ -278,6 +272,7 @@ std::shared_ptr<const IExpression> AddExpression::mulToLogarithm(const Arguments

ArgumentPtr AddExpression::sumRates(const IFunction & /*func*/, const ArgumentPtr &lhsChild,
const ArgumentPtr &rhsChild) {

auto [lhsChildRate, lhsChildValue] = getRateValuePair(lhsChild);
auto [rhsChildRate, rhsChildValue] = getRateValuePair(rhsChild);

Expand All @@ -294,18 +289,32 @@ ArgumentPtr AddExpression::sumDivisions(const IFunction & /*func*/, const Argume
std::shared_ptr<const IExpression> lhsExpr = cast<IExpression>(lhsChild);
std::shared_ptr<const IExpression> rhsExpr = cast<IExpression>(rhsChild);

ArgumentPtr res;

if (lhsExpr && is<Div>(lhsExpr->getFunction()) && rhsExpr && is<Div>(rhsExpr->getFunction())) {
if (*lhsExpr->getChildren().back() == *rhsExpr->getChildren().back()) {
ArgumentPtr divLhs = makeExpr(Add(), lhsExpr->getChildren().front(), rhsExpr->getChildren().front());
simplifyChild(divLhs);

ArgumentPtr divRhs = lhsExpr->getChildren().back();

return makeExpr(Div(), divLhs, divRhs);
ArgumentPtr numerator = makeExpr(Add(), lhsExpr->getChildren().front(), rhsExpr->getChildren().front());
ArgumentPtr denominator = lhsExpr->getChildren().back();
res = makeExpr(Div(), numerator, denominator);
}
else {
ArgumentPtr lhsNumberatorMulDenominator =
makeExpr(Mul(), lhsExpr->getChildren().front(), rhsExpr->getChildren().back());
ArgumentPtr rhsNumeratorMulDenominator =
makeExpr(Mul(), rhsExpr->getChildren().front(), lhsExpr->getChildren().back());
ArgumentPtr numerator = makeExpr(Add(), lhsNumberatorMulDenominator, rhsNumeratorMulDenominator);
ArgumentPtr denominator = makeExpr(Mul(), lhsExpr->getChildren().back(), rhsExpr->getChildren().back());
res = makeExpr(Div(), numerator, denominator);
}
}
else if (rhsExpr && is<Div>(rhsExpr->getFunction())) {
ArgumentPtr lhsMulDenominator = makeExpr(Mul(), lhsChild, rhsExpr->getChildren().back());
ArgumentPtr numerator = makeExpr(Add(), rhsExpr->getChildren().front(), lhsMulDenominator);
ArgumentPtr denominator = rhsExpr->getChildren().back();
res = makeExpr(Div(), numerator, denominator);
}

return {};
return res;
}

}
2 changes: 0 additions & 2 deletions src/fintamath/expressions/polynomial/AddExpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ class AddExpression : public IPolynomExpressionCRTP<AddExpression> {

SimplifyFunctionsVector getFunctionsForPreSimplify() const override;

SimplifyFunctionsVector getFunctionsForPostSimplify() const override;

std::string operatorChildToString(const ArgumentPtr &inChild, const ArgumentPtr &prevChild) const override;

/**
Expand Down
66 changes: 29 additions & 37 deletions tests/src/expressions/ExpressionTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,26 +209,6 @@ TEST(ExpressionTests, stringConstructorTest) {
EXPECT_EQ(Expression("1*a").toString(), "a");
EXPECT_EQ(Expression("a*1").toString(), "a");

EXPECT_EQ(Expression("(a/b)(c/d)").toString(), "(a c)/(b d)");
EXPECT_EQ(Expression("(ab/2)(ad/3)").toString(), "1/6 a^2 b d");
EXPECT_EQ(Expression("(-a)(-b)").toString(), "a b");
EXPECT_EQ(Expression("(a)(-b)").toString(), "-a b");
EXPECT_EQ(Expression("(-a)(b)").toString(), "-a b");
EXPECT_EQ(Expression("(5/3 b)/a").toString(), "(5 b)/(3 a)");
EXPECT_EQ(Expression("(a b)/(a b)").toString(), "1");
EXPECT_EQ(Expression("(a b)/1").toString(), "a b");
EXPECT_EQ(Expression("(a b)/-1").toString(), "-a b");
EXPECT_EQ(Expression("(a b)/-2").toString(), "-1/2 a b");
EXPECT_EQ(Expression("(a b)/(-a - b)").toString(), "-b + b^2/(a + b)");
EXPECT_EQ(Expression("(x^5)/(x - y)").toString(), "x^4 + x^3 y + x^2 y^2 + x y^3 + y^4 + y^5/(x - y)");
EXPECT_EQ(Expression("(3 x + 5/9)/(2y - 9/x + 3/2 x + 1/2 + 2 y / x)").toString(),
"2 - (16 y - 72)/(6 x^2 + 8 x y + 2 x + 8 y - 36) + (-16 x^2 y)/(6 x^3 + 8 x^2 y + 2 x^2 + 8 x y - 36 x) + "
"(-16 x)/(54 x^2 + 72 x y + 18 x + 72 y - 324)");
EXPECT_EQ(Expression("(a/x + b/(y+3/r)/4)/(3+t/5)").toString(),
"(5 a)/(t x + 15 x) + (25 b r)/(20 r t y + 300 r y + 60 t + 900)");
EXPECT_EQ(Expression("(x/a - (b+5)/(y-8/(12 y))/4)/(8-a/5)").toString(),
"-(300 b y + 1500 y)/(-240 a y^2 + 160 a + 9600 y^2 - 6400) + (5 x)/(-a^2 + 40 a)");

EXPECT_EQ(Expression("0^a").toString(), "0");
EXPECT_EQ(Expression("(a b)^0").toString(), "1");
EXPECT_EQ(Expression("(a + b)^-1").toString(), "1/(a + b)");
Expand Down Expand Up @@ -322,8 +302,6 @@ TEST(ExpressionTests, stringConstructorTest) {
EXPECT_EQ(Expression("2/(a + 2) + b/(a + 2)").toString(), "(b + 2)/(a + 2)");
EXPECT_EQ(Expression("c * 2^(a + 2) + b^(a + 2)").toString(), "b^(a + 2) + 2^(a + 2) c");
EXPECT_EQ(Expression("2^(a + 2) * b^(a + 2)").toString(), "b^(a + 2) 2^(a + 2)");
EXPECT_EQ(Expression("5/(a+b) + 5/(2a+b) + 5/(a+b)").toString(), "5/(2 a + b) + 10/(a + b)");
EXPECT_EQ(Expression("(x+y)/(a+b) + 5/(2a+b) + (x+2y)/(a+b)").toString(), "(2 x + 3 y)/(a + b) + 5/(2 a + b)");

EXPECT_EQ(Expression("(4x^2 - 5x - 21) / (x - 3)").toString(), "4 x + 7");
EXPECT_EQ(Expression("(3x^3 - 5x^2 + 10x - 3) / (3x + 1)").toString(), "x^2 - 2 x + 4 - 7/(3 x + 1)");
Expand All @@ -339,6 +317,35 @@ TEST(ExpressionTests, stringConstructorTest) {
EXPECT_EQ(Expression("(6x^8 - 7x^6 + 9x^4 - 4x^2 + 8) / (2x^3 - x^2 + 3x - 1)").toString(),
"3 x^5 + 3/2 x^4 - 29/4 x^3 - 35/8 x^2 + 223/16 x + 317/32 + (-1289 x^2 - 505 x + 573)/(64 x^3 - 32 x^2 + "
"96 x - 32)");
EXPECT_EQ(Expression("(2 a^3 + 5 a^2 b + 4 a b^2 + b^3)/(25 a^2 + 40 a b + 15 b^2)").toString(),
"2/25 a + 9/125 b + (-2 a b^2 - 2 b^3)/(625 a^2 + 1000 a b + 375 b^2)");
EXPECT_EQ(Expression("(25 a^2 + 40 a b + 15 b^2)/(2 a^3 + 5 a^2 b + 4 a b^2 + b^3)").toString(),
"(2 x + 3 y)/(a + b) + 5/(2 a + b)");
EXPECT_EQ(Expression("(x^2 + 2x + 1)/(x^3 + 3x^2 + 3x + 1)").toString(), "1/(x + 1)");
EXPECT_EQ(Expression("5/(a+b) + 5/(2a+b) + 5/(a+b)").toString(), "5/(2 a + b) + 10/(a + b)");
EXPECT_EQ(Expression("(x+y)/(a+b) + 5/(2a+b) + (x+2y)/(a+b)").toString(), "(2 x + 3 y)/(a + b) + 5/(2 a + b)");
EXPECT_EQ(Expression("(a/b)(c/d)").toString(), "(a c)/(b d)");
EXPECT_EQ(Expression("(ab/2)(ad/3)").toString(), "1/6 a^2 b d");
EXPECT_EQ(Expression("(-a)(-b)").toString(), "a b");
EXPECT_EQ(Expression("(a)(-b)").toString(), "-a b");
EXPECT_EQ(Expression("(-a)(b)").toString(), "-a b");
EXPECT_EQ(Expression("(5/3 b)/a").toString(), "(5 b)/(3 a)");
EXPECT_EQ(Expression("(a b)/(a b)").toString(), "1");
EXPECT_EQ(Expression("(a b)/1").toString(), "a b");
EXPECT_EQ(Expression("(a b)/-1").toString(), "-a b");
EXPECT_EQ(Expression("(a b)/-2").toString(), "-1/2 a b");
EXPECT_EQ(Expression("(a b)/(-a - b)").toString(), "-b + b^2/(a + b)");
EXPECT_EQ(Expression("(x^5)/(x - y)").toString(), "x^4 + x^3 y + x^2 y^2 + x y^3 + y^4 + y^5/(x - y)");
EXPECT_EQ(Expression("(3 x + 5/9)/(2y - 9/x + 3/2 x + 1/2 + 2 y / x)").toString(),
"2 - (16 y - 72)/(6 x^2 + 8 x y + 2 x + 8 y - 36) + (-16 x^2 y)/(6 x^3 + 8 x^2 y + 2 x^2 + 8 x y - 36 x) + "
"(-16 x)/(54 x^2 + 72 x y + 18 x + 72 y - 324)");
EXPECT_EQ(Expression("(a/x + b/(y+3/r)/4)/(3+t/5)").toString(),
"(5 a)/(t x + 15 x) + (25 b r)/(20 r t y + 300 r y + 60 t + 900)");
EXPECT_EQ(Expression("(x/a - (b+5)/(y-8/(12 y))/4)/(8-a/5)").toString(),
"-(300 b y + 1500 y)/(-240 a y^2 + 160 a + 9600 y^2 - 6400) + (5 x)/(-a^2 + 40 a)");
EXPECT_EQ(Expression("(a + b + c^2) / ((a + b + c^3) / (5/2 * (a + b) / (3/b + c/2)))").toString(),
"5 c + (5 a^2 b + 10 a b^2 - 30 a c + 5 b^3 + 180)/(a b c + 6 a + b^2 c + b c^4 + 6 b + 6 c^3) + (-5 c^4 - "
"30)/(a + b + c^3)");
EXPECT_EQ(Expression("( (2xy)/(x^2 - y^2) + (x - y)/(2x + 2y) ) * (2x)/(x + y) + y/(y - x)").toString(), "1");

// TODO! implement this
Expand Down Expand Up @@ -728,21 +735,6 @@ TEST(ExpressionTests, stringConstructorLargeTest) {
"x^7 y^23 + 593775 x^6 y^24 - 142506 x^5 y^25 + 27405 x^4 y^26 - 4060 x^3 y^27 + 435 x^2 y^28 - 30 x y^29 + "
"y^30");

EXPECT_EQ(
Expression("(a + b + c^2) / ((a + b + c^3) / (5/2 * (a + b) / (3/b + c/2)))").toString(),
"5 c - (2500 a^6 b^2 c^5 + 7500 a^5 b^3 c^5 + 7500 a^4 b^4 c^5 + 2500 a^3 b^5 c^5)/(500 a^7 b^2 c + 3000 a^7 b + "
"2000 a^6 b^3 c + 500 a^6 b^2 c^4 + 12000 a^6 b^2 + 3000 a^6 b c^3 + 3000 a^5 b^4 c + 1500 a^5 b^3 c^4 + 18000 "
"a^5 b^3 + 9000 a^5 b^2 c^3 + 2000 a^4 b^5 c + 1500 a^4 b^4 c^4 + 12000 a^4 b^4 + 9000 a^4 b^3 c^3 + 500 a^3 b^6 "
"c + 500 a^3 b^5 c^4 + 3000 a^3 b^5 + 3000 a^3 b^4 c^3) - (18750 a^4 c + 37500 a^3 b c + 18750 a^2 b^2 c)/(625 "
"a^4 b c + 3750 a^4 + 1250 a^3 b^2 c + 625 a^3 b c^4 + 7500 a^3 b + 3750 a^3 c^3 + 625 a^2 b^3 c + 625 a^2 b^2 "
"c^4 + 3750 a^2 b^2 + 3750 a^2 b c^3) - (-3000 a^3 b c^4 - 6000 a^2 b^2 c^4 - 3000 a b^3 c^4)/(100 a^5 b c + 600 "
"a^5 + 300 a^4 b^2 c + 100 a^4 b c^4 + 1800 a^4 b + 600 a^4 c^3 + 300 a^3 b^3 c + 200 a^3 b^2 c^4 + 1800 a^3 b^2 "
"+ 1200 a^3 b c^3 + 100 a^2 b^4 c + 100 a^2 b^3 c^4 + 600 a^2 b^3 + 600 a^2 b^2 c^3) - (187500 a^3 c^4 + 375000 "
"a^2 b c^4 + 187500 a b^2 c^4)/(6250 a^4 b c + 37500 a^4 + 12500 a^3 b^2 c + 6250 a^3 b c^4 + 75000 a^3 b + "
"37500 a^3 c^3 + 6250 a^2 b^3 c + 6250 a^2 b^2 c^4 + 37500 a^2 b^2 + 37500 a^2 b c^3) + (3125 a^5 b^2 + 9375 a^4 "
"b^3 + 9375 a^3 b^4 + 3125 a^2 b^5)/(625 a^4 b^2 c + 3750 a^4 b + 1250 a^3 b^3 c + 625 a^3 b^2 c^4 + 7500 a^3 "
"b^2 + 3750 a^3 b c^3 + 625 a^2 b^4 c + 625 a^2 b^3 c^4 + 3750 a^2 b^3 + 3750 a^2 b^2 c^3)");

EXPECT_EQ(
Expression("sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin("
"sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(sin("
Expand Down

0 comments on commit 07068af

Please sign in to comment.