diff --git a/include/fintamath/expressions/Expression.hpp b/include/fintamath/expressions/Expression.hpp index 81e9270e5..d6e0aa6b1 100644 --- a/include/fintamath/expressions/Expression.hpp +++ b/include/fintamath/expressions/Expression.hpp @@ -43,7 +43,7 @@ class Expression : public IExpressionCRTP { void setChildren(const ArgumentsPtrVector &childVect) override; - void setVariables(const std::vector &vars, const ArgumentsPtrVector &vals) override; + void setVariables(const std::vector> &varsToVals) override; void setVariable(const Variable &var, const Expression &val); diff --git a/include/fintamath/expressions/IExpression.hpp b/include/fintamath/expressions/IExpression.hpp index 70798d96d..c3ca7fc1a 100644 --- a/include/fintamath/expressions/IExpression.hpp +++ b/include/fintamath/expressions/IExpression.hpp @@ -19,7 +19,7 @@ class IExpression : public IArithmetic { std::vector getVariables() const; - virtual void setVariables(const std::vector &vars, const ArgumentsPtrVector &vals); + virtual void setVariables(const std::vector> &varsToVals); std::unique_ptr toMinimalObject() const final; diff --git a/src/fintamath/expressions/Expression.cpp b/src/fintamath/expressions/Expression.cpp index b02395349..ce0692ed0 100644 --- a/src/fintamath/expressions/Expression.cpp +++ b/src/fintamath/expressions/Expression.cpp @@ -517,13 +517,13 @@ void Expression::setChildren(const ArgumentsPtrVector &childVect) { child = childVect.front()->toMinimalObject(); } -void Expression::setVariables(const std::vector &vars, const ArgumentsPtrVector &vals) { - IExpression::setVariables(vars, vals); +void Expression::setVariables(const std::vector> &varsToVals) { + IExpression::setVariables(varsToVals); simplifyChild(child); } void Expression::setVariable(const Variable &var, const Expression &val) { - setVariables({var}, ArgumentsPtrVector{val.getChildren().front()}); + setVariables({{var, val.child}}); } Expression operator+(const Variable &lhs, const Variable &rhs) { diff --git a/src/fintamath/expressions/IExpression.cpp b/src/fintamath/expressions/IExpression.cpp index 333c29621..45beaa9f4 100644 --- a/src/fintamath/expressions/IExpression.cpp +++ b/src/fintamath/expressions/IExpression.cpp @@ -32,25 +32,27 @@ std::vector IExpression::getVariables() const { return vars; } -void IExpression::setVariables(const std::vector &vars, const ArgumentsPtrVector &vals) { +void IExpression::setVariables(const std::vector> &varsToVals) { auto children = getChildren(); ArgumentsPtrVector newChildren; for (auto &child : children) { if (std::shared_ptr exprChild = cast(child->clone())) { - exprChild->setVariables(vars, vals); + exprChild->setVariables(varsToVals); newChildren.emplace_back(exprChild); continue; } bool isAdded = false; - for (size_t i = 0; i < vars.size(); i++) { - if (const auto varChild = cast(child); varChild && *varChild == vars[i]) { - newChildren.push_back(vals[i]->clone()); - isAdded = true; - break; + if (const auto varChild = cast(child)) { + for (const auto &varsToVal : varsToVals) { + if (*varChild == varsToVal.first) { + newChildren.push_back(varsToVal.second); + isAdded = true; + break; + } } } diff --git a/src/fintamath/expressions/functions/ExpressionFunctionSolve.cpp b/src/fintamath/expressions/functions/ExpressionFunctionSolve.cpp index c923cee00..65a71e83e 100644 --- a/src/fintamath/expressions/functions/ExpressionFunctionSolve.cpp +++ b/src/fintamath/expressions/functions/ExpressionFunctionSolve.cpp @@ -189,13 +189,22 @@ ArgumentsPtrVector solveQuadraticEquation(const ArgumentsPtrVector &coeffAtPow) // TODO: remove this try/catch when complex numbers will be implemented try { Expression firstRootValue = firstRoot; - firstRootValue.setVariables({c, b, a}, coeffAtPow); + firstRootValue.setVariables({ + {c, coeffAtPow[0]}, // + {b, coeffAtPow[1]}, // + {a, coeffAtPow[2]}, // + }); Expression secondRootValue = secondRoot; - secondRootValue.setVariables({c, b, a}, coeffAtPow); + secondRootValue.setVariables({ + {c, coeffAtPow[0]}, // + {b, coeffAtPow[1]}, // + {a, coeffAtPow[2]}, // + }); return {firstRootValue.getChildren().front(), secondRootValue.getChildren().front()}; } + catch (const UndefinedException &) { return {}; }