diff --git a/include/fintamath/core/MathObjectUtils.hpp b/include/fintamath/core/MathObjectUtils.hpp index 80a6f3185..077aa95ed 100644 --- a/include/fintamath/core/MathObjectUtils.hpp +++ b/include/fintamath/core/MathObjectUtils.hpp @@ -129,4 +129,12 @@ inline std::shared_ptr cast(const std::shared_ptr &from) noexcept { return std::static_pointer_cast(from); } -} \ No newline at end of file +template +struct ToStringComparator { + template + bool operator()(const T &lhs, const T &rhs) const noexcept { + return Comparator{}(lhs.toString(), rhs.toString()); + } +}; + +} diff --git a/include/fintamath/expressions/IExpression.hpp b/include/fintamath/expressions/IExpression.hpp index dfa5e44ff..f7fb2bd48 100644 --- a/include/fintamath/expressions/IExpression.hpp +++ b/include/fintamath/expressions/IExpression.hpp @@ -2,12 +2,14 @@ #include #include +#include #include #include #include #include "fintamath/core/IMathObject.hpp" #include "fintamath/core/MathObjectClass.hpp" +#include "fintamath/core/MathObjectUtils.hpp" #include "fintamath/core/Parser.hpp" #include "fintamath/functions/FunctionArguments.hpp" #include "fintamath/functions/IFunction.hpp" @@ -20,6 +22,9 @@ namespace fintamath { class IExpression : public IMathObject { FINTAMATH_PARENT_CLASS_BODY(IExpression, IMathObject) +public: + using VariableSet = std::set>>; + public: virtual const std::shared_ptr &getFunction() const = 0; @@ -27,7 +32,7 @@ class IExpression : public IMathObject { virtual void setChildren(const ArgumentPtrVector &childVect) = 0; - std::vector getVariables() const; + VariableSet getVariables() const; virtual void setVariables(const std::vector> &varsToVals); diff --git a/src/fintamath/expressions/IExpression.cpp b/src/fintamath/expressions/IExpression.cpp index 9a863dac0..90bc34430 100644 --- a/src/fintamath/expressions/IExpression.cpp +++ b/src/fintamath/expressions/IExpression.cpp @@ -29,23 +29,19 @@ FINTAMATH_PARENT_CLASS_IMPLEMENTATION(IExpression) using namespace detail; -std::vector IExpression::getVariables() const { - std::vector vars; +IExpression::VariableSet IExpression::getVariables() const { + VariableSet vars; for (const auto &child : getChildren()) { if (auto var = cast(child)) { - vars.emplace_back(*var); + vars.emplace(*var); } else if (const auto childExpr = cast(child)) { - std::vector childVars = childExpr->getVariables(); - vars.insert(vars.end(), childVars.begin(), childVars.end()); + VariableSet childVars = childExpr->getVariables(); + vars.insert(childVars.begin(), childVars.end()); } } - std::ranges::sort(vars, std::less{}, &Variable::toString); - auto unique = std::ranges::unique(vars); - vars.erase(unique.begin(), unique.end()); - return vars; } diff --git a/src/fintamath/expressions/functions/ExpressionFunctionSolve.cpp b/src/fintamath/expressions/functions/ExpressionFunctionSolve.cpp index 361b48f94..3ba64b2f1 100644 --- a/src/fintamath/expressions/functions/ExpressionFunctionSolve.cpp +++ b/src/fintamath/expressions/functions/ExpressionFunctionSolve.cpp @@ -44,7 +44,7 @@ Expression solve(const Expression &rhs) { // TODO: remove this if when inequalities will be implemented if (!is(compExpr->getFunction())) { - const auto var = cast(compExpr->getVariables().front()); + const auto var = cast(*compExpr->getVariables().begin()); const ArgumentPtrVector powerRate = getPolynomCoefficients(compExpr->getChildren().front(), var); if (powerRate.size() == 2) { @@ -55,7 +55,7 @@ Expression solve(const Expression &rhs) { return rhs; } - const auto var = cast(compExpr->getVariables().front()); + const auto var = cast(*compExpr->getVariables().begin()); const ArgumentPtrVector powerRates = getPolynomCoefficients(compExpr->getChildren().front(), var); ArgumentPtrVector roots; diff --git a/tests/src/core/MathObjectUtilsTests.cpp b/tests/src/core/MathObjectUtilsTests.cpp index de79b21ed..b76df3b2e 100644 --- a/tests/src/core/MathObjectUtilsTests.cpp +++ b/tests/src/core/MathObjectUtilsTests.cpp @@ -57,3 +57,7 @@ TEST(MathObjectUtilsTests, castTest) { EXPECT_TRUE(cast(std::const_pointer_cast(std::shared_ptr(i.clone())))); EXPECT_FALSE(cast(std::const_pointer_cast(std::shared_ptr(E().clone())))); } + +TEST(MathObjectUtilsTests, compareByToStringTest) { + // TODO: implement +} diff --git a/tests/src/expressions/IExpressionTests.cpp b/tests/src/expressions/IExpressionTests.cpp index b49aac62f..528160b08 100644 --- a/tests/src/expressions/IExpressionTests.cpp +++ b/tests/src/expressions/IExpressionTests.cpp @@ -1,8 +1,10 @@ +#include #include #include "fintamath/expressions/IExpression.hpp" #include "fintamath/expressions/Expression.hpp" +#include "fintamath/expressions/ExpressionParser.hpp" #include "fintamath/expressions/ExpressionUtils.hpp" #include "fintamath/functions/arithmetic/Add.hpp" #include "fintamath/functions/other/Factorial.hpp" @@ -84,11 +86,8 @@ TEST(IExpressionTests, setChildrenTest) { } TEST(IExpressionTests, getVariablesTest) { - const auto expr = cast(Expression("x^2+y^2+a").clone()); - const auto vars = expr->getVariables(); - EXPECT_EQ(vars[0].toString(), "a"); - EXPECT_EQ(vars[1].toString(), "x"); - EXPECT_EQ(vars[2].toString(), "y"); + const auto expr = cast(parseExpr("x^2+y^2+a")); + EXPECT_THAT(expr->getVariables(), testing::ElementsAre(Variable("a"), Variable("x"), Variable("y"))); // TODO: implement more tests }