Skip to content

Commit

Permalink
Refactor functions and makeExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Aug 7, 2023
1 parent fa9aa8e commit c4ff06b
Show file tree
Hide file tree
Showing 13 changed files with 101 additions and 30 deletions.
4 changes: 4 additions & 0 deletions include/fintamath/expressions/Expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,12 @@ class Expression : public IExpressionCRTP<Expression> {

static void preciseRec(ArgumentPtr &arg, uint8_t precision);

friend std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentsPtrVector &args);

friend std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentsRefVector &args);

friend std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentsPtrVector &args);

friend std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentsPtrVector &args);

friend ArgumentPtr parseExpr(const std::string &str);
Expand Down
2 changes: 2 additions & 0 deletions include/fintamath/expressions/ExpressionUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ bool hasVariable(const std::shared_ptr<const IExpression> &expr, const Variable

std::vector<std::string> argumentVectorToStringVector(const ArgumentsPtrVector &args);

ArgumentsPtrVector argumentRefVectorToArgumentPtrVector(const ArgumentsRefVector &args);

}
4 changes: 4 additions & 0 deletions include/fintamath/functions/FunctionUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ class IFunction;

extern bool isExpression(const IMathObject &arg);

extern std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentsPtrVector &args);

extern std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentsRefVector &args);

extern std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentsPtrVector &args);

extern std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentsRefVector &args);

template <typename T, typename = std::enable_if_t<std::is_convertible_v<T, ArgumentPtr>>>
ArgumentPtr toArgumentPtr(T &arg) {
if constexpr (std::is_copy_constructible_v<T>) {
Expand Down
6 changes: 5 additions & 1 deletion include/fintamath/functions/IFunctionCRTP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ class IFunctionCRTP_ : public IFunction {
validateArgsSize(argsVect);

if (doArgsMatch(argsVect)) {
return call(argsVect);
if (auto res = call(argsVect)) {
return res;
}

return makeExpr(*this, argsVect);
}

return makeExprChecked(*this, argsVect);
Expand Down
18 changes: 10 additions & 8 deletions src/fintamath/expressions/Expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,17 +497,15 @@ void Expression::preciseRec(ArgumentPtr &arg, uint8_t precision) {
}
}

std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentsRefVector &args) {
ArgumentsPtrVector argsPtrVect;

for (const auto &arg : args) {
argsPtrVect.emplace_back(arg.get().clone());
}

Expression res(makeExpr(func, argsPtrVect));
std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentsPtrVector &args) {
Expression res(makeExpr(func, args));
return res.getChildren().front()->clone();
}

std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentsRefVector &args) {
return makeExprChecked(func, argumentRefVectorToArgumentPtrVector(args));
}

std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentsPtrVector &args) {
if (auto expr = Parser::parse(Expression::getExpressionMakers(), func.toString(), args)) {
return expr;
Expand All @@ -516,6 +514,10 @@ std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentsPtrV
return std::make_unique<FunctionExpression>(func, args);
}

std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentsRefVector &args) {
return makeExpr(func, argumentRefVectorToArgumentPtrVector(args));
}

void Expression::setChildren(const ArgumentsPtrVector &childVect) {
if (childVect.size() != 1) {
throw InvalidInputFunctionException("", argumentVectorToStringVector(childVect));
Expand Down
12 changes: 11 additions & 1 deletion src/fintamath/expressions/ExpressionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,18 @@ std::vector<std::string> argumentVectorToStringVector(const ArgumentsPtrVector &
return argStrings;
}

ArgumentsPtrVector argumentRefVectorToArgumentPtrVector(const ArgumentsRefVector &args) {
ArgumentsPtrVector argsPtrVect;

for (const auto &arg : args) {
argsPtrVect.emplace_back(arg.get().clone());
}

return argsPtrVect;
}

bool isExpression(const IMathObject &arg) {
return is<IExpression>(arg);
}

}
}
8 changes: 4 additions & 4 deletions src/fintamath/functions/other/Factorial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ std::unique_ptr<IMathObject> Factorial::multiFactorialSimpl(const INumber &lhs,
std::unique_ptr<IMathObject> Factorial::factorialSimpl(const Integer &rhs, size_t order) {
if (rhs < 0) {
if (order != 1) {
return makeExpr(Factorial(order), rhs);
return {};
}

return ComplexInf().clone();
Expand All @@ -53,22 +53,22 @@ std::unique_ptr<IMathObject> Factorial::factorialSimpl(const Rational &rhs, size
}

if (order != 1) {
return makeExpr(Factorial(order), rhs);
return {};
}

return factorialSimpl(Real(rhs), order);
}

std::unique_ptr<IMathObject> Factorial::factorialSimpl(const Real &rhs, size_t order) {
if (order != 1) {
return makeExpr(Factorial(order), rhs);
return {};
}

try {
return tgamma(rhs + 1).toMinimalObject();
}
catch (const UndefinedException &) {
return makeExpr(Factorial(order), rhs);
return {};
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/fintamath/functions/powers/Pow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ std::unique_ptr<IMathObject> Pow::powSimpl(const Rational &lhs, const Rational &

if (lhs < Integer(0)) {
// TODO: complex numbers
return makeExpr(Pow(), lhs, rhs);
return {};
}

if (lhsDenominator == 1) {
Expand All @@ -94,7 +94,7 @@ std::unique_ptr<IMathObject> Pow::powSimpl(const Real &lhs, const Real &rhs) {
return pow(lhs, rhs).toMinimalObject();
}
catch (const UndefinedException &) {
return makeExpr(Pow(), lhs, rhs);
return {};
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/fintamath/functions/powers/Root.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ std::unique_ptr<IMathObject> Root::call(const ArgumentsRefVector &argsVect) cons
if (rhsInt > Integer(1)) {
if (lhs < Integer(0)) {
// TODO: complex numbers
return makeExpr(Root(), lhs, rhs);
return {};
}

return multiRootSimpl(lhs, rhsInt);
Expand Down
4 changes: 4 additions & 0 deletions tests/src/expressions/ExpressionUtilsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ TEST(ExpressionUtilsTests, argumentVectorToStringVectorTest) {
// TODO: implement
}

TEST(ExpressionUtilsTests, argumentRefVectorToArgumentPtrVector) {
// TODO: implement
}

TEST(ExpressionUtilsTests, hasVariableTest) {
auto expr = std::make_shared<Expression>("cos(sin(a))");
EXPECT_TRUE(hasVariable(expr, Variable("a")));
Expand Down
63 changes: 52 additions & 11 deletions tests/src/functions/FunctionUtilsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,25 @@

using namespace fintamath;

TEST(FunctionUtilsTests, makeExpressionCheckedTest) {
TEST(FunctionUtilsTests, makeExpressionCheckedPtrsTest) {
ArgumentPtr one = std::make_unique<Integer>(1);
ArgumentPtr two = std::make_unique<Integer>(2);
auto expr1 = makeExprChecked(Add(), {one, two});
EXPECT_EQ(expr1->toString(), "3");
EXPECT_TRUE(is<INumber>(expr1));

ArgumentPtr var = std::make_unique<Variable>("a");
ArgumentPtr expr2 = makeExprChecked(Cos(), {var});
EXPECT_EQ(expr2->toString(), "cos(a)");
EXPECT_TRUE(is<IExpression>(expr2));

EXPECT_THROW(makeExprChecked(Mul(), ArgumentsPtrVector{var})->toString(), InvalidInputException);
EXPECT_THROW(makeExprChecked(Mul(), ArgumentsPtrVector{})->toString(), InvalidInputException);
EXPECT_THROW(makeExprChecked(Pow(), ArgumentsPtrVector{var})->toString(), InvalidInputException);
EXPECT_THROW(makeExprChecked(Pow(), ArgumentsPtrVector{})->toString(), InvalidInputException);
}

TEST(FunctionUtilsTests, makeExpressionCheckedRefsTest) {
Integer one = 1;
Integer two = 2;
auto expr1 = makeExprChecked(Add(), {one, two});
Expand All @@ -29,24 +47,47 @@ TEST(FunctionUtilsTests, makeExpressionCheckedTest) {
auto expr2 = makeExprChecked(Cos(), {var});
EXPECT_EQ(expr2->toString(), "cos(a)");
EXPECT_TRUE(is<IExpression>(expr2));

EXPECT_THROW(makeExprChecked(Mul(), ArgumentsRefVector{var})->toString(), InvalidInputException);
EXPECT_THROW(makeExprChecked(Mul(), ArgumentsRefVector{})->toString(), InvalidInputException);
EXPECT_THROW(makeExprChecked(Pow(), ArgumentsRefVector{var})->toString(), InvalidInputException);
EXPECT_THROW(makeExprChecked(Pow(), ArgumentsRefVector{})->toString(), InvalidInputException);
}

TEST(FunctionUtilsTests, makeExpressionTest) {
auto expr1 = makeExpr(Add(), {std::make_shared<Integer>(1), std::make_shared<Integer>(2)});
TEST(FunctionUtilsTests, makeExpressionPtrsTest) {
ArgumentPtr one = std::make_unique<Integer>(1);
ArgumentPtr two = std::make_unique<Integer>(2);
auto expr1 = makeExpr(Add(), {one, two});
EXPECT_EQ(expr1->toString(), "1 + 2");
EXPECT_TRUE(is<IExpression>(expr1));
EXPECT_FALSE(is<Expression>(expr1));
EXPECT_FALSE(is<INumber>(expr1));

auto var = std::make_shared<Variable>("a");
ArgumentPtr var = std::make_unique<Variable>("a");
ArgumentPtr expr2 = makeExpr(Cos(), {var});
EXPECT_EQ(expr2->toString(), "cos(a)");
EXPECT_TRUE(is<IExpression>(expr2));

EXPECT_THROW(makeExpr(Mul(), ArgumentsPtrVector{var})->toString(), InvalidInputException);
EXPECT_THROW(makeExpr(Mul(), ArgumentsPtrVector{})->toString(), InvalidInputException);
EXPECT_THROW(makeExpr(Pow(), ArgumentsPtrVector{var})->toString(), InvalidInputException);
EXPECT_THROW(makeExpr(Pow(), ArgumentsPtrVector{})->toString(), InvalidInputException);
}

TEST(FunctionUtilsTests, makeExpressionRefsTest) {
Integer one = 1;
Integer two = 2;
auto expr1 = makeExpr(Add(), {one, two});
EXPECT_EQ(expr1->toString(), "1 + 2");
EXPECT_FALSE(is<INumber>(expr1));

Variable var("a");
auto expr2 = makeExpr(Cos(), {var});
EXPECT_EQ(expr2->toString(), "cos(a)");
EXPECT_TRUE(is<IExpression>(expr2));
EXPECT_FALSE(is<Expression>(expr2));

EXPECT_THROW(makeExpr(Mul(), {var})->toString(), InvalidInputException);
EXPECT_THROW(makeExpr(Mul(), {})->toString(), InvalidInputException);
EXPECT_THROW(makeExpr(Pow(), {var})->toString(), InvalidInputException);
EXPECT_THROW(makeExpr(Pow(), {})->toString(), InvalidInputException);
EXPECT_THROW(makeExpr(Mul(), ArgumentsRefVector{var})->toString(), InvalidInputException);
EXPECT_THROW(makeExpr(Mul(), ArgumentsRefVector{})->toString(), InvalidInputException);
EXPECT_THROW(makeExpr(Pow(), ArgumentsRefVector{var})->toString(), InvalidInputException);
EXPECT_THROW(makeExpr(Pow(), ArgumentsRefVector{})->toString(), InvalidInputException);
}

TEST(FunctionUtilsTests, makeExpressionCheckedAnyArgsTest) {
Expand Down
2 changes: 1 addition & 1 deletion tests/src/functions/powers/PowTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ TEST(PowTests, callTest) {
EXPECT_EQ(f(Real("2.2"), Real("0.5"))->toString(),
"1.48323969741913258974227948816014261219598086381950031974652465286876603686277");

EXPECT_EQ(f(Rational(-10), Rational("-1.5"))->toString(), "(-1/10)^(3/2)");
EXPECT_EQ(f(Rational(-10), Rational("-1.5"))->toString(), "(-10)^(-3/2)");

EXPECT_EQ(f(Integer(0), Integer(-1))->toString(), "ComplexInf");
EXPECT_EQ(f(Integer(0), Integer(-10))->toString(), "ComplexInf");
Expand Down
2 changes: 1 addition & 1 deletion tests/src/functions/powers/RootTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ TEST(RootTests, callTest) {

EXPECT_EQ(f(Integer(-10), Integer(2))->toString(), "sqrt(-10)");
EXPECT_EQ(f(Rational(-9289, 10), Rational(2, 3))->toString(), "(-9289/10)^(3/2)");
EXPECT_EQ(f(Real(-9289), Rational(2, 3))->toString(), "(-9289.0)^1.5");
EXPECT_EQ(f(Real(-9289), Rational(2, 3))->toString(), "(-9289.0)^(3/2)");

EXPECT_THROW(f(), InvalidInputFunctionException);
EXPECT_THROW(f(Integer(1), Integer(1), Integer(1)), InvalidInputFunctionException);
Expand Down

0 comments on commit c4ff06b

Please sign in to comment.