diff --git a/include/fintamath/expressions/Expression.hpp b/include/fintamath/expressions/Expression.hpp index e5d3419e7..c5184a1b9 100644 --- a/include/fintamath/expressions/Expression.hpp +++ b/include/fintamath/expressions/Expression.hpp @@ -136,8 +136,12 @@ class Expression : public IExpressionCRTP { static void preciseRec(ArgumentPtr &arg, uint8_t precision); + friend std::unique_ptr makeExprChecked(const IFunction &func, const ArgumentsPtrVector &args); + friend std::unique_ptr makeExprChecked(const IFunction &func, const ArgumentsRefVector &args); + friend std::unique_ptr makeExprChecked(const IFunction &func, const ArgumentsPtrVector &args); + friend std::unique_ptr makeExpr(const IFunction &func, const ArgumentsPtrVector &args); friend ArgumentPtr parseExpr(const std::string &str); diff --git a/include/fintamath/expressions/ExpressionUtils.hpp b/include/fintamath/expressions/ExpressionUtils.hpp index 271659765..efd802772 100644 --- a/include/fintamath/expressions/ExpressionUtils.hpp +++ b/include/fintamath/expressions/ExpressionUtils.hpp @@ -28,4 +28,6 @@ bool hasVariable(const std::shared_ptr &expr, const Variable std::vector argumentVectorToStringVector(const ArgumentsPtrVector &args); +ArgumentsPtrVector argumentRefVectorToArgumentPtrVector(const ArgumentsRefVector &args); + } diff --git a/include/fintamath/functions/FunctionUtils.hpp b/include/fintamath/functions/FunctionUtils.hpp index 1fa99cb9d..04647682b 100644 --- a/include/fintamath/functions/FunctionUtils.hpp +++ b/include/fintamath/functions/FunctionUtils.hpp @@ -13,10 +13,14 @@ class IFunction; extern bool isExpression(const IMathObject &arg); +extern std::unique_ptr makeExprChecked(const IFunction &func, const ArgumentsPtrVector &args); + extern std::unique_ptr makeExprChecked(const IFunction &func, const ArgumentsRefVector &args); extern std::unique_ptr makeExpr(const IFunction &func, const ArgumentsPtrVector &args); +extern std::unique_ptr makeExpr(const IFunction &func, const ArgumentsRefVector &args); + template >> ArgumentPtr toArgumentPtr(T &arg) { if constexpr (std::is_copy_constructible_v) { diff --git a/include/fintamath/functions/IFunctionCRTP.hpp b/include/fintamath/functions/IFunctionCRTP.hpp index 74731972a..89e339515 100644 --- a/include/fintamath/functions/IFunctionCRTP.hpp +++ b/include/fintamath/functions/IFunctionCRTP.hpp @@ -56,7 +56,11 @@ class IFunctionCRTP_ : public IFunction { validateArgsSize(argsVect); if (doArgsMatch(argsVect)) { - return call(argsVect); + if (auto res = call(argsVect)) { + return res; + } + + return makeExpr(*this, argsVect); } return makeExprChecked(*this, argsVect); diff --git a/src/fintamath/expressions/Expression.cpp b/src/fintamath/expressions/Expression.cpp index e820e9c6c..f9f95c33a 100644 --- a/src/fintamath/expressions/Expression.cpp +++ b/src/fintamath/expressions/Expression.cpp @@ -497,17 +497,15 @@ void Expression::preciseRec(ArgumentPtr &arg, uint8_t precision) { } } -std::unique_ptr makeExprChecked(const IFunction &func, const ArgumentsRefVector &args) { - ArgumentsPtrVector argsPtrVect; - - for (const auto &arg : args) { - argsPtrVect.emplace_back(arg.get().clone()); - } - - Expression res(makeExpr(func, argsPtrVect)); +std::unique_ptr makeExprChecked(const IFunction &func, const ArgumentsPtrVector &args) { + Expression res(makeExpr(func, args)); return res.getChildren().front()->clone(); } +std::unique_ptr makeExprChecked(const IFunction &func, const ArgumentsRefVector &args) { + return makeExprChecked(func, argumentRefVectorToArgumentPtrVector(args)); +} + std::unique_ptr makeExpr(const IFunction &func, const ArgumentsPtrVector &args) { if (auto expr = Parser::parse(Expression::getExpressionMakers(), func.toString(), args)) { return expr; @@ -516,6 +514,10 @@ std::unique_ptr makeExpr(const IFunction &func, const ArgumentsPtrV return std::make_unique(func, args); } +std::unique_ptr makeExpr(const IFunction &func, const ArgumentsRefVector &args) { + return makeExpr(func, argumentRefVectorToArgumentPtrVector(args)); +} + void Expression::setChildren(const ArgumentsPtrVector &childVect) { if (childVect.size() != 1) { throw InvalidInputFunctionException("", argumentVectorToStringVector(childVect)); diff --git a/src/fintamath/expressions/ExpressionUtils.cpp b/src/fintamath/expressions/ExpressionUtils.cpp index db669832c..a81daab7f 100644 --- a/src/fintamath/expressions/ExpressionUtils.cpp +++ b/src/fintamath/expressions/ExpressionUtils.cpp @@ -141,8 +141,18 @@ std::vector argumentVectorToStringVector(const ArgumentsPtrVector & return argStrings; } +ArgumentsPtrVector argumentRefVectorToArgumentPtrVector(const ArgumentsRefVector &args) { + ArgumentsPtrVector argsPtrVect; + + for (const auto &arg : args) { + argsPtrVect.emplace_back(arg.get().clone()); + } + + return argsPtrVect; +} + bool isExpression(const IMathObject &arg) { return is(arg); } -} \ No newline at end of file +} diff --git a/src/fintamath/functions/other/Factorial.cpp b/src/fintamath/functions/other/Factorial.cpp index 579eb5baf..009afe8b2 100644 --- a/src/fintamath/functions/other/Factorial.cpp +++ b/src/fintamath/functions/other/Factorial.cpp @@ -38,7 +38,7 @@ std::unique_ptr Factorial::multiFactorialSimpl(const INumber &lhs, std::unique_ptr Factorial::factorialSimpl(const Integer &rhs, size_t order) { if (rhs < 0) { if (order != 1) { - return makeExpr(Factorial(order), rhs); + return {}; } return ComplexInf().clone(); @@ -53,7 +53,7 @@ std::unique_ptr Factorial::factorialSimpl(const Rational &rhs, size } if (order != 1) { - return makeExpr(Factorial(order), rhs); + return {}; } return factorialSimpl(Real(rhs), order); @@ -61,14 +61,14 @@ std::unique_ptr Factorial::factorialSimpl(const Rational &rhs, size std::unique_ptr Factorial::factorialSimpl(const Real &rhs, size_t order) { if (order != 1) { - return makeExpr(Factorial(order), rhs); + return {}; } try { return tgamma(rhs + 1).toMinimalObject(); } catch (const UndefinedException &) { - return makeExpr(Factorial(order), rhs); + return {}; } } diff --git a/src/fintamath/functions/powers/Pow.cpp b/src/fintamath/functions/powers/Pow.cpp index c3e20ce79..0a8587322 100644 --- a/src/fintamath/functions/powers/Pow.cpp +++ b/src/fintamath/functions/powers/Pow.cpp @@ -79,7 +79,7 @@ std::unique_ptr Pow::powSimpl(const Rational &lhs, const Rational & if (lhs < Integer(0)) { // TODO: complex numbers - return makeExpr(Pow(), lhs, rhs); + return {}; } if (lhsDenominator == 1) { @@ -94,7 +94,7 @@ std::unique_ptr Pow::powSimpl(const Real &lhs, const Real &rhs) { return pow(lhs, rhs).toMinimalObject(); } catch (const UndefinedException &) { - return makeExpr(Pow(), lhs, rhs); + return {}; } } diff --git a/src/fintamath/functions/powers/Root.cpp b/src/fintamath/functions/powers/Root.cpp index dbae8bbf7..58c145a97 100644 --- a/src/fintamath/functions/powers/Root.cpp +++ b/src/fintamath/functions/powers/Root.cpp @@ -26,7 +26,7 @@ std::unique_ptr Root::call(const ArgumentsRefVector &argsVect) cons if (rhsInt > Integer(1)) { if (lhs < Integer(0)) { // TODO: complex numbers - return makeExpr(Root(), lhs, rhs); + return {}; } return multiRootSimpl(lhs, rhsInt); diff --git a/tests/src/expressions/ExpressionUtilsTests.cpp b/tests/src/expressions/ExpressionUtilsTests.cpp index 79d6c5193..3379a060c 100644 --- a/tests/src/expressions/ExpressionUtilsTests.cpp +++ b/tests/src/expressions/ExpressionUtilsTests.cpp @@ -43,6 +43,10 @@ TEST(ExpressionUtilsTests, argumentVectorToStringVectorTest) { // TODO: implement } +TEST(ExpressionUtilsTests, argumentRefVectorToArgumentPtrVector) { + // TODO: implement +} + TEST(ExpressionUtilsTests, hasVariableTest) { auto expr = std::make_shared("cos(sin(a))"); EXPECT_TRUE(hasVariable(expr, Variable("a"))); diff --git a/tests/src/functions/FunctionUtilsTests.cpp b/tests/src/functions/FunctionUtilsTests.cpp index 0fbeeba14..f5571f6bd 100644 --- a/tests/src/functions/FunctionUtilsTests.cpp +++ b/tests/src/functions/FunctionUtilsTests.cpp @@ -18,7 +18,25 @@ using namespace fintamath; -TEST(FunctionUtilsTests, makeExpressionCheckedTest) { +TEST(FunctionUtilsTests, makeExpressionCheckedPtrsTest) { + ArgumentPtr one = std::make_unique(1); + ArgumentPtr two = std::make_unique(2); + auto expr1 = makeExprChecked(Add(), {one, two}); + EXPECT_EQ(expr1->toString(), "3"); + EXPECT_TRUE(is(expr1)); + + ArgumentPtr var = std::make_unique("a"); + ArgumentPtr expr2 = makeExprChecked(Cos(), {var}); + EXPECT_EQ(expr2->toString(), "cos(a)"); + EXPECT_TRUE(is(expr2)); + + EXPECT_THROW(makeExprChecked(Mul(), ArgumentsPtrVector{var})->toString(), InvalidInputException); + EXPECT_THROW(makeExprChecked(Mul(), ArgumentsPtrVector{})->toString(), InvalidInputException); + EXPECT_THROW(makeExprChecked(Pow(), ArgumentsPtrVector{var})->toString(), InvalidInputException); + EXPECT_THROW(makeExprChecked(Pow(), ArgumentsPtrVector{})->toString(), InvalidInputException); +} + +TEST(FunctionUtilsTests, makeExpressionCheckedRefsTest) { Integer one = 1; Integer two = 2; auto expr1 = makeExprChecked(Add(), {one, two}); @@ -29,24 +47,47 @@ TEST(FunctionUtilsTests, makeExpressionCheckedTest) { auto expr2 = makeExprChecked(Cos(), {var}); EXPECT_EQ(expr2->toString(), "cos(a)"); EXPECT_TRUE(is(expr2)); + + EXPECT_THROW(makeExprChecked(Mul(), ArgumentsRefVector{var})->toString(), InvalidInputException); + EXPECT_THROW(makeExprChecked(Mul(), ArgumentsRefVector{})->toString(), InvalidInputException); + EXPECT_THROW(makeExprChecked(Pow(), ArgumentsRefVector{var})->toString(), InvalidInputException); + EXPECT_THROW(makeExprChecked(Pow(), ArgumentsRefVector{})->toString(), InvalidInputException); } -TEST(FunctionUtilsTests, makeExpressionTest) { - auto expr1 = makeExpr(Add(), {std::make_shared(1), std::make_shared(2)}); +TEST(FunctionUtilsTests, makeExpressionPtrsTest) { + ArgumentPtr one = std::make_unique(1); + ArgumentPtr two = std::make_unique(2); + auto expr1 = makeExpr(Add(), {one, two}); EXPECT_EQ(expr1->toString(), "1 + 2"); - EXPECT_TRUE(is(expr1)); - EXPECT_FALSE(is(expr1)); + EXPECT_FALSE(is(expr1)); - auto var = std::make_shared("a"); + ArgumentPtr var = std::make_unique("a"); + ArgumentPtr expr2 = makeExpr(Cos(), {var}); + EXPECT_EQ(expr2->toString(), "cos(a)"); + EXPECT_TRUE(is(expr2)); + + EXPECT_THROW(makeExpr(Mul(), ArgumentsPtrVector{var})->toString(), InvalidInputException); + EXPECT_THROW(makeExpr(Mul(), ArgumentsPtrVector{})->toString(), InvalidInputException); + EXPECT_THROW(makeExpr(Pow(), ArgumentsPtrVector{var})->toString(), InvalidInputException); + EXPECT_THROW(makeExpr(Pow(), ArgumentsPtrVector{})->toString(), InvalidInputException); +} + +TEST(FunctionUtilsTests, makeExpressionRefsTest) { + Integer one = 1; + Integer two = 2; + auto expr1 = makeExpr(Add(), {one, two}); + EXPECT_EQ(expr1->toString(), "1 + 2"); + EXPECT_FALSE(is(expr1)); + + Variable var("a"); auto expr2 = makeExpr(Cos(), {var}); EXPECT_EQ(expr2->toString(), "cos(a)"); EXPECT_TRUE(is(expr2)); - EXPECT_FALSE(is(expr2)); - EXPECT_THROW(makeExpr(Mul(), {var})->toString(), InvalidInputException); - EXPECT_THROW(makeExpr(Mul(), {})->toString(), InvalidInputException); - EXPECT_THROW(makeExpr(Pow(), {var})->toString(), InvalidInputException); - EXPECT_THROW(makeExpr(Pow(), {})->toString(), InvalidInputException); + EXPECT_THROW(makeExpr(Mul(), ArgumentsRefVector{var})->toString(), InvalidInputException); + EXPECT_THROW(makeExpr(Mul(), ArgumentsRefVector{})->toString(), InvalidInputException); + EXPECT_THROW(makeExpr(Pow(), ArgumentsRefVector{var})->toString(), InvalidInputException); + EXPECT_THROW(makeExpr(Pow(), ArgumentsRefVector{})->toString(), InvalidInputException); } TEST(FunctionUtilsTests, makeExpressionCheckedAnyArgsTest) { diff --git a/tests/src/functions/powers/PowTests.cpp b/tests/src/functions/powers/PowTests.cpp index dbe4725af..8e42f2816 100644 --- a/tests/src/functions/powers/PowTests.cpp +++ b/tests/src/functions/powers/PowTests.cpp @@ -75,7 +75,7 @@ TEST(PowTests, callTest) { EXPECT_EQ(f(Real("2.2"), Real("0.5"))->toString(), "1.48323969741913258974227948816014261219598086381950031974652465286876603686277"); - EXPECT_EQ(f(Rational(-10), Rational("-1.5"))->toString(), "(-1/10)^(3/2)"); + EXPECT_EQ(f(Rational(-10), Rational("-1.5"))->toString(), "(-10)^(-3/2)"); EXPECT_EQ(f(Integer(0), Integer(-1))->toString(), "ComplexInf"); EXPECT_EQ(f(Integer(0), Integer(-10))->toString(), "ComplexInf"); diff --git a/tests/src/functions/powers/RootTests.cpp b/tests/src/functions/powers/RootTests.cpp index 279058217..ce9641ffb 100644 --- a/tests/src/functions/powers/RootTests.cpp +++ b/tests/src/functions/powers/RootTests.cpp @@ -94,7 +94,7 @@ TEST(RootTests, callTest) { EXPECT_EQ(f(Integer(-10), Integer(2))->toString(), "sqrt(-10)"); EXPECT_EQ(f(Rational(-9289, 10), Rational(2, 3))->toString(), "(-9289/10)^(3/2)"); - EXPECT_EQ(f(Real(-9289), Rational(2, 3))->toString(), "(-9289.0)^1.5"); + EXPECT_EQ(f(Real(-9289), Rational(2, 3))->toString(), "(-9289.0)^(3/2)"); EXPECT_THROW(f(), InvalidInputFunctionException); EXPECT_THROW(f(Integer(1), Integer(1), Integer(1)), InvalidInputFunctionException);