Skip to content

Commit

Permalink
Return set in IExpression::getVariables
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Apr 8, 2024
1 parent ec773c1 commit 9e7ae40
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 18 deletions.
10 changes: 9 additions & 1 deletion include/fintamath/core/MathObjectUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,12 @@ inline std::shared_ptr<To> cast(const std::shared_ptr<From> &from) noexcept {
return std::static_pointer_cast<To>(from);
}

}
template <typename Comparator>
struct ToStringComparator {
template <typename T>
bool operator()(const T &lhs, const T &rhs) const noexcept {
return Comparator{}(lhs.toString(), rhs.toString());
}
};

}
7 changes: 6 additions & 1 deletion include/fintamath/expressions/IExpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

#include <concepts>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "fintamath/core/IMathObject.hpp"
#include "fintamath/core/MathObjectClass.hpp"
#include "fintamath/core/MathObjectUtils.hpp"
#include "fintamath/core/Parser.hpp"
#include "fintamath/functions/FunctionArguments.hpp"
#include "fintamath/functions/IFunction.hpp"
Expand All @@ -20,14 +22,17 @@ namespace fintamath {
class IExpression : public IMathObject {
FINTAMATH_PARENT_CLASS_BODY(IExpression, IMathObject)

public:
using VariableSet = std::set<Variable, ToStringComparator<std::less<>>>;

public:
virtual const std::shared_ptr<IFunction> &getFunction() const = 0;

virtual const ArgumentPtrVector &getChildren() const = 0;

virtual void setChildren(const ArgumentPtrVector &childVect) = 0;

std::vector<Variable> getVariables() const;
VariableSet getVariables() const;

virtual void setVariables(const std::vector<std::pair<Variable, ArgumentPtr>> &varsToVals);

Expand Down
14 changes: 5 additions & 9 deletions src/fintamath/expressions/IExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,19 @@ FINTAMATH_PARENT_CLASS_IMPLEMENTATION(IExpression)

using namespace detail;

std::vector<Variable> IExpression::getVariables() const {
std::vector<Variable> vars;
IExpression::VariableSet IExpression::getVariables() const {
VariableSet vars;

for (const auto &child : getChildren()) {
if (auto var = cast<Variable>(child)) {
vars.emplace_back(*var);
vars.emplace(*var);
}
else if (const auto childExpr = cast<IExpression>(child)) {
std::vector<Variable> childVars = childExpr->getVariables();
vars.insert(vars.end(), childVars.begin(), childVars.end());
VariableSet childVars = childExpr->getVariables();
vars.insert(childVars.begin(), childVars.end());
}
}

std::ranges::sort(vars, std::less{}, &Variable::toString);
auto unique = std::ranges::unique(vars);
vars.erase(unique.begin(), unique.end());

return vars;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Expression solve(const Expression &rhs) {

// TODO: remove this if when inequalities will be implemented
if (!is<Eqv>(compExpr->getFunction())) {
const auto var = cast<Variable>(compExpr->getVariables().front());
const auto var = cast<Variable>(*compExpr->getVariables().begin());
const ArgumentPtrVector powerRate = getPolynomCoefficients(compExpr->getChildren().front(), var);

if (powerRate.size() == 2) {
Expand All @@ -55,7 +55,7 @@ Expression solve(const Expression &rhs) {
return rhs;
}

const auto var = cast<Variable>(compExpr->getVariables().front());
const auto var = cast<Variable>(*compExpr->getVariables().begin());
const ArgumentPtrVector powerRates = getPolynomCoefficients(compExpr->getChildren().front(), var);
ArgumentPtrVector roots;

Expand Down
4 changes: 4 additions & 0 deletions tests/src/core/MathObjectUtilsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,7 @@ TEST(MathObjectUtilsTests, castTest) {
EXPECT_TRUE(cast<IArithmetic>(std::const_pointer_cast<const IMathObject>(std::shared_ptr(i.clone()))));
EXPECT_FALSE(cast<IArithmetic>(std::const_pointer_cast<const IMathObject>(std::shared_ptr(E().clone()))));
}

TEST(MathObjectUtilsTests, compareByToStringTest) {
// TODO: implement
}
9 changes: 4 additions & 5 deletions tests/src/expressions/IExpressionTests.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "fintamath/expressions/IExpression.hpp"

#include "fintamath/expressions/Expression.hpp"
#include "fintamath/expressions/ExpressionParser.hpp"
#include "fintamath/expressions/ExpressionUtils.hpp"
#include "fintamath/functions/arithmetic/Add.hpp"
#include "fintamath/functions/other/Factorial.hpp"
Expand Down Expand Up @@ -84,11 +86,8 @@ TEST(IExpressionTests, setChildrenTest) {
}

TEST(IExpressionTests, getVariablesTest) {
const auto expr = cast<IExpression>(Expression("x^2+y^2+a").clone());
const auto vars = expr->getVariables();
EXPECT_EQ(vars[0].toString(), "a");
EXPECT_EQ(vars[1].toString(), "x");
EXPECT_EQ(vars[2].toString(), "y");
const auto expr = cast<IExpression>(parseExpr("x^2+y^2+a"));
EXPECT_THAT(expr->getVariables(), testing::ElementsAre(Variable("a"), Variable("x"), Variable("y")));

// TODO: implement more tests
}
Expand Down

0 comments on commit 9e7ae40

Please sign in to comment.