Skip to content

Commit

Permalink
Improve DivExpression simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Jul 12, 2023
1 parent 31da863 commit 2cd6d62
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 20 deletions.
71 changes: 68 additions & 3 deletions src/fintamath/expressions/binary/DivExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "fintamath/functions/arithmetic/Neg.hpp"
#include "fintamath/functions/arithmetic/Sub.hpp"
#include "fintamath/functions/powers/Pow.hpp"
#include "fintamath/numbers/IntegerFunctions.hpp"
#include "fintamath/numbers/Rational.hpp"

namespace fintamath {
Expand All @@ -33,6 +34,7 @@ DivExpression::SimplifyFunctionsVector DivExpression::getFunctionsForPostSimplif
&DivExpression::divSimplify, //
&DivExpression::mulSimplify, //
&DivExpression::nestedRationalsSimplify, //
&DivExpression::gcdSimplify, //
&DivExpression::sumSimplify, //
};
return simplifyFunctions;
Expand Down Expand Up @@ -209,18 +211,22 @@ bool DivExpression::isNeg(const ArgumentPtr &expr) {
}

ArgumentPtr DivExpression::sumSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
if (auto [result, remainder] = sumMulSimplify(lhs, rhs); result) {
if (auto [result, remainder] = mulSumSimplify(lhs, rhs); result) {
return makeExpr(Add(), result, remainder);
}

if (auto [result, remainder] = mulSumSimplify(lhs, rhs); result) {
if (auto [result, remainder] = sumMulSimplify(lhs, rhs); result) {
return makeExpr(Add(), result, remainder);
}

if (auto [result, remainder] = sumSumSimplify(lhs, rhs); result) {
return makeExpr(Add(), result, remainder);
}

if (auto [result, remainder] = sumSumSimplify(rhs, lhs); result && !is<IExpression>(remainder)) {
return makeExpr(Div(), Integer(1).clone(), makeExpr(Add(), result, remainder));
}

return {};
}

Expand Down Expand Up @@ -426,6 +432,66 @@ ArgumentPtr DivExpression::nestedRationalsSimplify(const IFunction & /*func*/, c
return {};
}

ArgumentPtr DivExpression::gcdSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
ArgumentsPtrVector lhsChildren;
ArgumentsPtrVector rhsChildren;

if (const auto lhsExpr = cast<IExpression>(lhs); lhsExpr && is<Add>(lhsExpr->getFunction())) {
lhsChildren = lhsExpr->getChildren();
}
else {
lhsChildren = {lhs};
}

if (const auto rhsExpr = cast<IExpression>(rhs); rhsExpr && is<Add>(rhsExpr->getFunction())) {
rhsChildren = rhsExpr->getChildren();
}
else {
rhsChildren = {rhs};
}

Integer lhsGcdNum = getGcd(lhsChildren);
Integer rhsGcdNum = getGcd(rhsChildren);

if (lhsGcdNum <= 1 || rhsGcdNum <= 1) {
return {};
}

Integer gcdNum = gcd(lhsGcdNum, rhsGcdNum);

if (gcdNum <= 1) {
return {};
}

ArgumentPtr numerator = makeExpr(Div(), lhs, gcdNum.clone());
simplifyChild(numerator);

ArgumentPtr denominator = makeExpr(Div(), rhs, gcdNum.clone());
simplifyChild(denominator);

return makeExpr(Div(), numerator, denominator);
}

Integer DivExpression::getGcd(ArgumentsPtrVector &lhsChildren) {
Integer gcdNum;

for (auto child : lhsChildren) {
if (const auto childExpr = cast<IExpression>(child); childExpr && is<Mul>(childExpr->getFunction())) {
child = childExpr->getChildren().front();
}

if (const auto childInt = cast<Integer>(child)) {
Integer childIntAbs = abs(*childInt);
gcdNum = gcdNum != 0 ? gcd(gcdNum, childIntAbs) : childIntAbs;
}
else {
return 1;
}
}

return gcdNum;
}

ArgumentPtr DivExpression::nestedRationalsInNumeratorSimplify(const ArgumentsPtrVector &lhsChildren,
const ArgumentPtr &rhs) {

Expand Down Expand Up @@ -478,5 +544,4 @@ ArgumentPtr DivExpression::nestedRationalsInDenominatorSimplify(const ArgumentPt

return {};
}

}
4 changes: 4 additions & 0 deletions src/fintamath/expressions/binary/DivExpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class DivExpression : public IBinaryExpressionCRTP<DivExpression> {

static ArgumentPtr nestedRationalsSimplify(const IFunction &func, const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static ArgumentPtr gcdSimplify(const IFunction &func, const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static ArgumentPtr nestedRationalsInNumeratorSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs);

static ArgumentPtr nestedRationalsInDenominatorSimplify(const ArgumentPtr &lhs,
Expand All @@ -51,6 +53,8 @@ class DivExpression : public IBinaryExpressionCRTP<DivExpression> {
static ArgumentPtr addRatesToValue(const ArgumentsPtrVector &rates, const ArgumentPtr &value);

static bool isNeg(const ArgumentPtr &expr);

static Integer getGcd(ArgumentsPtrVector &children);
};

}
36 changes: 19 additions & 17 deletions tests/src/expressions/ExpressionTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,6 @@ TEST(ExpressionTests, stringConstructorTest) {
EXPECT_EQ(Expression("(a+b)^3").toString(), "a^3 + 3 a^2 b + 3 a b^2 + b^3");
EXPECT_EQ(Expression("1*(a+b)^3").toString(), "a^3 + 3 a^2 b + 3 a b^2 + b^3");
EXPECT_EQ(Expression("(a+b)^4").toString(), "a^4 + 4 a^3 b + 6 a^2 b^2 + 4 a b^3 + b^4");
EXPECT_EQ(Expression("(a+3)/(b+2)").toString(), "(a + 3)/(b + 2)");
EXPECT_EQ(Expression("b/a*(a+3)/(b+2)").toString(), "1 + (3 b)/(a b + 2 a) - 2/(b + 2)");
EXPECT_EQ(Expression("(5+b)/a*(a+3)/(b+2)").toString(),
"1 - 30/(a b^2 + 2 a b) - 6/(a b + 2 a) + 15/(a b) + 3/a + 3/(b + 2)");
EXPECT_EQ(Expression("(a+b)*(a+b)/(a+b)").toString(), "a + b");
EXPECT_EQ(Expression("(a+b)*(a+b)*(1/(a+b))").toString(), "a + b");
EXPECT_EQ(Expression("(x^2+2x+1)/(x+1)").toString(), "x + 1");
EXPECT_EQ(Expression("1*(a+b)*1").toString(), "a + b");
EXPECT_EQ(Expression("-1*(a+b)*1").toString(), "-a - b");
EXPECT_EQ(Expression("1*(a+b)*-1").toString(), "-a - b");
Expand Down Expand Up @@ -302,6 +295,7 @@ TEST(ExpressionTests, stringConstructorTest) {
EXPECT_EQ(Expression("c * 2^(a + 2) + b^(a + 2)").toString(), "b^(a + 2) + 2^(a + 2) c");
EXPECT_EQ(Expression("2^(a + 2) * b^(a + 2)").toString(), "b^(a + 2) 2^(a + 2)");

EXPECT_EQ(Expression("-9 / (3x+3)").toString(), "-3/(x + 1)");
EXPECT_EQ(Expression("(4x^2 - 5x - 21) / (x - 3)").toString(), "4 x + 7");
EXPECT_EQ(Expression("(3x^3 - 5x^2 + 10x - 3) / (3x + 1)").toString(), "x^2 - 2 x + 4 - 7/(3 x + 1)");
EXPECT_EQ(Expression("(2x^3 - 9x^2 + 15) / (2x - 5)").toString(), "x^2 - 2 x - 5 - 10/(2 x - 5)");
Expand All @@ -314,15 +308,17 @@ TEST(ExpressionTests, stringConstructorTest) {
EXPECT_EQ(Expression("(3x^6 + 5x^5 - 2x^4 + 4x^3 + x^2 + 3x - 5) / (x^4 + 3x^2 - 2)").toString(),
"3 x^2 + 5 x - 11 + (-11 x^3 + 40 x^2 + 13 x - 27)/(x^4 + 3 x^2 - 2)");
EXPECT_EQ(Expression("(6x^8 - 7x^6 + 9x^4 - 4x^2 + 8) / (2x^3 - x^2 + 3x - 1)").toString(),
"3 x^5 + 3/2 x^4 - 29/4 x^3 - 35/8 x^2 + 223/16 x + 317/32 + (-1289 x^2 - 505 x + 573)/(64 x^3 - 32 x^2 + "
"96 x - 32)");
"3 x^5 + (3 x^4)/2 - (29 x^3)/4 - (35 x^2)/8 + (223 x)/16 + 317/32 + (-1289 x^2 - 505 x + 573)/(64 x^3 - "
"32 x^2 + 96 x - 32)");
EXPECT_EQ(Expression("(2 a^3 + 5 a^2 b + 4 a b^2 + b^3)/(25 a^2 + 40 a b + 15 b^2)").toString(),
"2/25 a + 9/125 b + (-2 a b^2 - 2 b^3)/(625 a^2 + 1000 a b + 375 b^2)");
"(2 a)/25 + (9 b)/125 + (-2 a b^2 - 2 b^3)/(625 a^2 + 1000 a b + 375 b^2)");
EXPECT_EQ(Expression("(25 a^2 + 40 a b + 15 b^2)/(2 a^3 + 5 a^2 b + 4 a b^2 + b^3)").toString(),
"(2 x + 3 y)/(a + b) + 5/(2 a + b)");
"(25 a^2 + 40 a b + 15 b^2)/(2 a^3 + 5 a^2 b + 4 a b^2 + b^3)");
EXPECT_EQ(Expression("(x^2 + 2x + 1)/(x^3 + 3x^2 + 3x + 1)").toString(), "1/(x + 1)");
EXPECT_EQ(Expression("5/(a+b) + 5/(2a+b) + 5/(a+b)").toString(), "5/(2 a + b) + 10/(a + b)");
EXPECT_EQ(Expression("(x+y)/(a+b) + 5/(2a+b) + (x+2y)/(a+b)").toString(), "(2 x + 3 y)/(a + b) + 5/(2 a + b)");
EXPECT_EQ(Expression("5/(a+b) + 5/(2a+b) + 5/(a+b)").toString(),
"(25 a^2 + 40 a b + 15 b^2)/(2 a^3 + 5 a^2 b + 4 a b^2 + b^3)");
EXPECT_EQ(Expression("(x+y)/(a+b) + 5/(2a+b) + (x+2y)/(a+b)").toString(),
"(4 a x + 6 a y + 5 a + 2 b x + 3 b y + 5 b)/(2 a^2 + 3 a b + b^2)");
EXPECT_EQ(Expression("(a/b)(c/d)").toString(), "(a c)/(b d)");
EXPECT_EQ(Expression("(ab/2)(ad/3)").toString(), "(a^2 b d)/6");
EXPECT_EQ(Expression("(-a)(-b)").toString(), "a b");
Expand All @@ -339,15 +335,21 @@ TEST(ExpressionTests, stringConstructorTest) {
EXPECT_EQ(Expression("(3 x + 5/9)/(2y - 9/x + 3/2 x + 1/2 + 2 y / x)").toString(),
"2 + (-72 x y - 8 x - 72 y + 324)/(27 x^2 + 36 x y + 9 x + 36 y - 162)");
EXPECT_EQ(Expression("(a/x + b/(y+3/r)/4)/(3+t/5)").toString(),
"(5 a)/(t x + 15 x) + (25 b r)/(20 r t y + 300 r y + 60 t + 900)");
"(20 a r y + 60 a + 5 b r x)/(4 r t x y + 60 r x y + 12 t x + 180 x)");
EXPECT_EQ(Expression("(x/a - (b+5)/(y-8/(12 y))/4)/(8-a/5)").toString(),
"-(300 b y + 1500 y)/(-240 a y^2 + 160 a + 9600 y^2 - 6400) + (5 x)/(-a^2 + 40 a)");
"(-a b - 5 a + 4 x y + (-8 x)/(3 y))/(-(4 a^2 y)/5 + 32 a y + (8 a^2)/(15 y) + (-64 a)/(3 y))");
EXPECT_EQ(Expression("(a + b + c^2) / ((a + b + c^3) / (5/2 * (a + b) / (3/b + c/2)))").toString(),
"5 c + (5 a^2 b + 10 a b^2 - 30 a c + 5 b^3 + 180)/(a b c + 6 a + b^2 c + b c^4 + 6 b + 6 c^3) + (-5 c^4 - "
"30)/(a + b + c^3)");
"5 c + (5 a^2 b + 10 a b^2 - 30 a c + 5 b^3 - 5 b c^5 - 30 b c - 30 c^4)/(a b c + 6 a + b^2 c + b c^4 + 6 "
"b + 6 c^3)");
EXPECT_EQ(Expression("((2xy)/(x^2 - y^2) + (x - y)/(2x + 2y)) * (2x)/(x + y) + y/(y - x)").toString(), "1");
EXPECT_EQ(Expression("y/(x - y) - (x ^3 - xy ^2)/(x ^2 + y ^2) * (x/((x - y) ^2) - y/(x ^2 - y ^2))").toString(),
"-1");
EXPECT_EQ(Expression("(a+3)/(b+2)").toString(), "(a + 3)/(b + 2)");
EXPECT_EQ(Expression("b/a*(a+3)/(b+2)").toString(), "1 + (-2 a + 3 b)/(a b + 2 a)");
EXPECT_EQ(Expression("(5+b)/a*(a+3)/(b+2)").toString(), "1 + (3 a + 3 b + 15)/(a b + 2 a)");
EXPECT_EQ(Expression("(a+b)*(a+b)/(a+b)").toString(), "a + b");
EXPECT_EQ(Expression("(a+b)*(a+b)*(1/(a+b))").toString(), "a + b");
EXPECT_EQ(Expression("(x^2+2x+1)/(x+1)").toString(), "x + 1");

// TODO! implement this
// EXPECT_EQ(Expression("(x/y)^2").toString(), "(x^2)/(y^2)");
Expand Down

0 comments on commit 2cd6d62

Please sign in to comment.