Skip to content

Commit

Permalink
Refactor splitPowExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Jan 9, 2024
1 parent ecd6b26 commit 6859efd
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 45 deletions.
4 changes: 2 additions & 2 deletions src/fintamath/expressions/ExpressionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,10 @@ std::pair<ArgumentPtr, ArgumentPtr> splitMulExpr(const ArgumentPtr &inChild, boo
std::pair<ArgumentPtr, ArgumentPtr> splitPowExpr(const ArgumentPtr &rhs) {
if (const auto &powExpr = cast<IExpression>(rhs); powExpr && is<Pow>(powExpr->getFunction())) {
const ArgumentPtrVector &powExprChildren = powExpr->getChildren();
return {powExprChildren[1], powExprChildren[0]};
return {powExprChildren[0], powExprChildren[1]};
}

return {one, rhs};
return {rhs, one};
}

std::pair<ArgumentPtr, ArgumentPtr> splitRational(const ArgumentPtr &arg) {
Expand Down
32 changes: 16 additions & 16 deletions src/fintamath/expressions/binary/DivExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,30 +404,30 @@ ArgumentPtr DivExpression::equalSimplify(const IFunction & /*func*/, const Argum
}

ArgumentPtr DivExpression::powSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
auto [lhsChildRate, lhsChildValue] = splitPowExpr(lhs);
auto [rhsChildRate, rhsChildValue] = splitPowExpr(rhs);
auto [lhsChildBase, lhsChildRate] = splitPowExpr(lhs);
auto [rhsChildBase, rhsChildRate] = splitPowExpr(rhs);

auto lhsChildValueNum = cast<INumber>(lhsChildValue);
auto lhsChildBaseNum = cast<INumber>(rhsChildBase);
auto lhsChildRateNum = cast<INumber>(lhsChildRate);
auto rhsChildValueNum = cast<INumber>(rhsChildValue);
auto rhsChildBaseNum = cast<INumber>(rhsChildBase);
auto rhsChildRateNum = cast<INumber>(rhsChildRate);

if (lhsChildValueNum && rhsChildValueNum &&
if (lhsChildBaseNum && rhsChildBaseNum &&
lhsChildRateNum && rhsChildRateNum &&
*lhsChildRateNum < *rhsChildRateNum) {

return {};
}

if (*lhsChildValue == *rhsChildValue && !containsInfinity(lhsChildValue)) {
return powExpr(lhsChildValue, addExpr(lhsChildRate, negExpr(rhsChildRate)));
if (*lhsChildBase == *rhsChildBase && !containsInfinity(rhsChildBase)) {
return powExpr(rhsChildBase, addExpr(lhsChildRate, negExpr(rhsChildRate)));
}

if (rhsChildValueNum) {
if (rhsChildBaseNum) {
if (const auto rhsChildRateRat = cast<Rational>(rhsChildRate)) {
ArgumentPtr numeratorPow = Pow()(*rhsChildValue, 1 - (*rhsChildRateRat));
ArgumentPtr numeratorPow = Pow()(*rhsChildBase, 1 - (*rhsChildRateRat));
ArgumentPtr numerator = mulExpr(lhs, numeratorPow);
return divExpr(numerator, rhsChildValue);
return divExpr(numerator, rhsChildBase);
}
}

Expand Down Expand Up @@ -519,15 +519,15 @@ ArgumentPtr DivExpression::nestedNumeratorRationalSimplify(const ArgumentPtrVect
}

ArgumentPtr DivExpression::tanCotSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
auto [lhsChildRate, lhsChildValue] = splitPowExpr(lhs);
auto [rhsChildRate, rhsChildValue] = splitPowExpr(rhs);
auto [lhsChildBase, lhsChildRate] = splitPowExpr(lhs);
auto [rhsChildBase, rhsChildRate] = splitPowExpr(rhs);

if (*lhsChildRate != *rhsChildRate) {
return {};
}

auto lhsChildValueExpr = cast<IExpression>(lhsChildValue);
auto rhsChildValueExpr = cast<IExpression>(rhsChildValue);
auto lhsChildValueExpr = cast<IExpression>(lhsChildBase);
auto rhsChildValueExpr = cast<IExpression>(rhsChildBase);

if (lhsChildValueExpr && rhsChildValueExpr &&
*lhsChildValueExpr->getChildren().front() == *rhsChildValueExpr->getChildren().front()) {
Expand All @@ -553,9 +553,9 @@ ArgumentPtr DivExpression::tanCotSimplify(const IFunction & /*func*/, const Argu
}

ArgumentPtr DivExpression::secCscSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
auto [rhsChildRate, rhsChildValue] = splitPowExpr(rhs);
auto [rhsChildBase, rhsChildRate] = splitPowExpr(rhs);

if (auto rhsChildValueExpr = cast<IExpression>(rhsChildValue)) {
if (auto rhsChildValueExpr = cast<IExpression>(rhsChildBase)) {
if (is<Sin>(rhsChildValueExpr->getFunction())) {
return mulExpr(lhs, powExpr(cscExpr(*rhsChildValueExpr->getChildren().front()), rhsChildRate));
}
Expand Down
8 changes: 4 additions & 4 deletions src/fintamath/expressions/binary/PowExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ std::vector<Integer> PowExpression::getPartition(Integer bitNumber, const Intege

// Uses multinomial theorem for exponentiation of sum.
// https://en.wikipedia.org/wiki/Multinomial_theorem
ArgumentPtr PowExpression::sumPolynomSimplify(const ArgumentPtr &expr, const Integer &powValue) {
ArgumentPtr PowExpression::sumPolynomSimplify(const ArgumentPtr &expr, const Integer &power) {
auto sumExpr = cast<IExpression>(expr);
ArgumentPtrVector polynom;

Expand All @@ -169,8 +169,8 @@ ArgumentPtr PowExpression::sumPolynomSimplify(const ArgumentPtr &expr, const Int
}

size_t variableCount = polynom.size();
Integer bitNumber = generateFirstNum(powValue);
Integer combins = combinations(powValue + variableCount - 1, powValue);
Integer bitNumber = generateFirstNum(power);
Integer combins = combinations(power + variableCount - 1, power);

ArgumentPtrVector newChildren;

Expand All @@ -179,7 +179,7 @@ ArgumentPtr PowExpression::sumPolynomSimplify(const ArgumentPtr &expr, const Int
bitNumber = generateNextNumber(bitNumber);

ArgumentPtrVector mulExprChildren;
mulExprChildren.emplace_back(multinomialCoefficient(powValue, vectOfPows).clone());
mulExprChildren.emplace_back(multinomialCoefficient(power, vectOfPows).clone());

for (auto i : std::views::iota(0U, variableCount)) {
ArgumentPtr powExprChild = powExpr(polynom[i], vectOfPows[i].clone());
Expand Down
2 changes: 1 addition & 1 deletion src/fintamath/expressions/binary/PowExpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class PowExpression : public IBinaryExpressionCRTP<PowExpression> {
SimplifyFunctionVector getFunctionsForPostSimplify() const override;

private:
static ArgumentPtr sumPolynomSimplify(const ArgumentPtr &sumExpr, const Integer &powValue);
static ArgumentPtr sumPolynomSimplify(const ArgumentPtr &sumExpr, const Integer &power);

static Integer generateNextNumber(const Integer &n);

Expand Down
30 changes: 15 additions & 15 deletions src/fintamath/expressions/polynomial/AddExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,32 +325,32 @@ ArgumentPtr AddExpression::trigSimplify(const IFunction & /*func*/, const Argume
auto [lhsMulRate, lhsMulValue] = splitMulExpr(lhs, false);
auto [rhsMulRate, rhsMulValue] = splitMulExpr(rhs, false);

auto [lhsPowRate, lhsPowValue] = splitPowExpr(lhsMulValue);
auto [rhsPowRate, rhsPowValue] = splitPowExpr(rhsMulValue);
auto [lhsPowBase, lhsPowRate] = splitPowExpr(lhsMulValue);
auto [rhsPowBase, rhsPowRate] = splitPowExpr(rhsMulValue);

auto lhsPowValueExpr = cast<IExpression>(lhsPowValue);
auto rhsPowValueExpr = cast<IExpression>(rhsPowValue);
auto lhsPowBaseExpr = cast<IExpression>(lhsPowBase);
auto rhsPowBaseExpr = cast<IExpression>(rhsPowBase);

if (!lhsPowValueExpr || *lhsPowRate != Integer(2)) {
if (!lhsPowBaseExpr || *lhsPowRate != Integer(2)) {
return {};
}

if (containsInfinity(lhsPowValue) || containsInfinity(rhsPowValue)) {
if (containsInfinity(lhsPowBase) || containsInfinity(rhsPowBase)) {
return {};
}

auto lhsPowValueChild = lhsPowValueExpr->getChildren().front();
auto lhsPowBaseChild = lhsPowBaseExpr->getChildren().front();

auto lhsMulRateNum = cast<INumber>(lhsMulRate);

if (rhsPowValueExpr && *rhsPowRate == Integer(2)) {
if (!is<Sin>(lhsPowValueExpr->getFunction()) || !is<Cos>(rhsPowValueExpr->getFunction())) {
if (rhsPowBaseExpr && *rhsPowRate == Integer(2)) {
if (!is<Sin>(lhsPowBaseExpr->getFunction()) || !is<Cos>(rhsPowBaseExpr->getFunction())) {
return {};
}

auto rhsPowValueChild = rhsPowValueExpr->getChildren().front();
auto rhsPowBaseChild = rhsPowBaseExpr->getChildren().front();

if (*lhsPowValueChild != *rhsPowValueChild) {
if (*lhsPowBaseChild != *rhsPowBaseChild) {
return {};
}

Expand All @@ -363,7 +363,7 @@ ArgumentPtr AddExpression::trigSimplify(const IFunction & /*func*/, const Argume
if (lhsMulRateNum && rhsMulRateNum && *(*lhsMulRateNum + *rhsMulRateNum) == Integer(0)) {
ArgumentPtr res = cosExpr(
mulExpr(
lhsPowValueExpr->getChildren().front(),
lhsPowBaseExpr->getChildren().front(),
Integer(2).clone()));

return mulExpr(rhsMulRateNum, res);
Expand All @@ -375,12 +375,12 @@ ArgumentPtr AddExpression::trigSimplify(const IFunction & /*func*/, const Argume
auto rhsNum = cast<INumber>(rhs);

if (lhsMulRateNum && rhsNum && *(*lhsMulRateNum + *rhsNum) == Integer(0)) {
ArgumentPtr res = lhsPowValueExpr->getChildren().front();
ArgumentPtr res = lhsPowBaseExpr->getChildren().front();

if (is<Sin>(lhsPowValueExpr->getFunction())) {
if (is<Sin>(lhsPowBaseExpr->getFunction())) {
res = cosExpr(res);
}
else if (is<Cos>(lhsPowValueExpr->getFunction())) {
else if (is<Cos>(lhsPowBaseExpr->getFunction())) {
res = sinExpr(res);
}
else {
Expand Down
14 changes: 7 additions & 7 deletions src/fintamath/expressions/polynomial/MulExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,16 @@ ArgumentPtr MulExpression::polynomSimplify(const IFunction & /*func*/, const Arg
}

ArgumentPtr MulExpression::powSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
auto [lhsChildRate, lhsChildValue] = splitPowExpr(lhs);
auto [rhsChildRate, rhsChildValue] = splitPowExpr(rhs);
auto [lhsChildBase, lhsChildRate] = splitPowExpr(lhs);
auto [rhsChildBase, rhsChildRate] = splitPowExpr(rhs);

if (*lhsChildValue == *rhsChildValue) {
if (*lhsChildBase == *rhsChildBase) {
ArgumentPtr ratesSum = addExpr(lhsChildRate, rhsChildRate);
return powExpr(lhsChildValue, ratesSum);
return powExpr(lhsChildBase, ratesSum);
}

auto lhsChildValueNum = cast<INumber>(lhsChildValue);
auto rhsChildValueNum = cast<INumber>(rhsChildValue);
auto lhsChildValueNum = cast<INumber>(lhsChildBase);
auto rhsChildValueNum = cast<INumber>(rhsChildBase);

if (lhsChildValueNum &&
rhsChildValueNum &&
Expand All @@ -280,7 +280,7 @@ ArgumentPtr MulExpression::powSimplify(const IFunction & /*func*/, const Argumen
*lhsChildRate == *rhsChildRate &&
*rhsChildRate != Integer(1)) {

ArgumentPtr valuesMul = mulExpr(lhsChildValue, rhsChildValue);
ArgumentPtr valuesMul = mulExpr(lhsChildBase, rhsChildBase);
return powExpr(valuesMul, lhsChildRate);
}

Expand Down

0 comments on commit 6859efd

Please sign in to comment.