Skip to content

Commit

Permalink
Refactor ExpressionFunctionSolve and fix solve(x = 3^x)
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Jan 9, 2024
1 parent 6859efd commit 4160b7b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 84 deletions.
110 changes: 26 additions & 84 deletions src/fintamath/expressions/functions/ExpressionFunctionSolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,9 @@ namespace fintamath {

namespace {

std::shared_ptr<const INumber> getElementPower(const ArgumentPtr &elem, const Variable &var);
const size_t maxPower = 4;

std::shared_ptr<const INumber> getMulElementPower(const std::shared_ptr<const IExpression> &elem, const Variable &var);

ArgumentPtr getElementRate(const ArgumentPtr &elem, const Variable &var);

ArgumentPtrVector getVariableIntPowerRates(const ArgumentPtr &elem, const Variable &var);
ArgumentPtrVector getPolynomCoefficients(const ArgumentPtr &elem, const Variable &var);

ArgumentPtrVector solveCubicEquation(const ArgumentPtrVector &coeffAtPow);

Expand All @@ -45,7 +41,7 @@ Expression solve(const Expression &rhs) {
// TODO: remove this if when inequalities will be implemented
if (!is<Eqv>(compExpr->getFunction())) {
auto var = cast<Variable>(compExpr->getVariables().front());
ArgumentPtrVector powerRate = getVariableIntPowerRates(compExpr->getChildren().front(), var);
ArgumentPtrVector powerRate = getPolynomCoefficients(compExpr->getChildren().front(), var);

if (powerRate.size() == 2) {
compExpr->markAsSolution();
Expand All @@ -56,7 +52,7 @@ Expression solve(const Expression &rhs) {
}

auto var = cast<Variable>(compExpr->getVariables().front());
ArgumentPtrVector powerRates = getVariableIntPowerRates(compExpr->getChildren().front(), var);
ArgumentPtrVector powerRates = getPolynomCoefficients(compExpr->getChildren().front(), var);
ArgumentPtrVector roots;

switch (powerRates.size()) {
Expand All @@ -71,6 +67,7 @@ Expression solve(const Expression &rhs) {
break;
default:
roots = {};
break;
}

if (roots.empty()) {
Expand All @@ -94,66 +91,8 @@ Expression solve(const Expression &rhs) {

namespace {

std::shared_ptr<const INumber> getElementPower(const ArgumentPtr &elem, const Variable &var) {
if (const auto elemVar = cast<Variable>(elem); elemVar && *elemVar == var) {
return std::make_shared<Integer>(1);
}

if (const auto expr = cast<IExpression>(elem)) {
if (is<Mul>(expr->getFunction())) {
return getMulElementPower(expr, var);
}

if (is<Pow>(expr->getFunction())) {
if (const auto elemVar = cast<Variable>(expr->getChildren().front()); elemVar && *elemVar == var) {
return cast<INumber>(expr->getChildren().back());
}
}
}

return std::make_shared<Integer>(0);
}

std::shared_ptr<const INumber> getMulElementPower(const std::shared_ptr<const IExpression> &elem, const Variable &var) {
for (const auto &child : elem->getChildren()) {
if (auto powValue = getElementPower(child, var); *powValue != Integer(0)) {
return powValue;
}
}

return std::make_shared<Integer>(0);
}

ArgumentPtr getElementRate(const ArgumentPtr &elem, const Variable &var) {
if (const auto elemExpr = cast<IExpression>(elem)) {
if (is<Pow>(elemExpr->getFunction())) {
if (containsVariable(elemExpr, var)) {
return Integer(1).clone();
}

return elem;
}

if (is<Mul>(elemExpr->getFunction())) {
ArgumentPtrVector coeffs = {Integer(1).clone()};

for (const auto &child : elemExpr->getChildren()) {
coeffs.emplace_back(getElementRate(child, var));
}

return mulExpr(std::move(coeffs))->toMinimalObject();
}
}

if (const auto elemVar = cast<Variable>(elem); elemVar && var == *elemVar) {
return Integer(1).clone();
}

return elem;
}

ArgumentPtrVector getVariableIntPowerRates(const ArgumentPtr &elem, const Variable &var) {
ArgumentPtrVector powerRates;
ArgumentPtrVector getPolynomCoefficients(const ArgumentPtr &elem, const Variable &var) {
ArgumentPtrVector powers;
ArgumentPtrVector polynomVect;

if (const auto exprVal = cast<IExpression>(elem); exprVal && is<Add>(exprVal->getFunction())) {
Expand All @@ -164,28 +103,27 @@ ArgumentPtrVector getVariableIntPowerRates(const ArgumentPtr &elem, const Variab
}

for (const auto &polynomChild : polynomVect) {
ArgumentPtr rate = getElementRate(polynomChild, var);
std::shared_ptr<const INumber> power = getElementPower(polynomChild, var);
if (!containsVariable(polynomChild, var)) {
powers[0] = polynomChild;
continue;
}

if (auto intPow = cast<Integer>(power)) {
if (powerRates.size() < *intPow + 1) {
while (powerRates.size() != *intPow + 1) {
powerRates.emplace_back(Integer(0).clone());
}
}
auto [mulRate, mulValue] = splitMulExpr(polynomChild);
auto [powBase, powValue] = splitPowExpr(mulValue);
auto intPower = cast<Integer>(powValue);

powerRates[size_t(*intPow)] = rate;
}
else {
if (!intPower || *intPower > maxPower || *powBase != var) {
return {};
}
}

return powerRates;
}
while (powers.size() < *intPower + 1) {
powers.emplace_back(Integer(0).clone());
}

ArgumentPtrVector solveCubicEquation(const ArgumentPtrVector & /*coeffAtPow*/) {
return {};
powers[size_t(*intPower)] = mulRate;
}

return powers;
}

ArgumentPtrVector solveQuadraticEquation(const ArgumentPtrVector &coeffAtPow) {
Expand All @@ -202,6 +140,10 @@ ArgumentPtrVector solveQuadraticEquation(const ArgumentPtrVector &coeffAtPow) {
return {firstRoot.getChildren().front(), secondRoot.getChildren().front()};
}

ArgumentPtrVector solveCubicEquation(const ArgumentPtrVector & /*coeffAtPow*/) {
return {};
}

ArgumentPtrVector solveLinearEquation(const ArgumentPtrVector &coeffAtPow) {
return {negExpr(divExpr(coeffAtPow[0], coeffAtPow[1]))->toMinimalObject()};
}
Expand Down
1 change: 1 addition & 0 deletions tests/src/expressions/ExpressionFunctionsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ TEST(ExpressionFunctionsTests, solveTest) {
EXPECT_EQ(solve(Expression("x/y = 0")).toString(), "x = 0");
EXPECT_EQ(solve(Expression("x^2 - 2*sin(2) = 0")).toString(), "x = -sqrt(8 sin(2))/2 | x = sqrt(8 sin(2))/2");
EXPECT_EQ(solve(Expression("x = x sqrt(x)")).toString(), "x^(3/2) - x = 0");
EXPECT_EQ(solve(Expression("x = 3^x")).toString(), "x - 3^x = 0");

EXPECT_EQ(solve(Expression("E = Ey")).toString(), "y = 1");
EXPECT_EQ(solve(Expression("sin(4) = sin(4) y")).toString(), "y = 1");
Expand Down

0 comments on commit 4160b7b

Please sign in to comment.