diff --git a/include/fintamath/expressions/ExpressionUtils.hpp b/include/fintamath/expressions/ExpressionUtils.hpp index 6edb3f49d..4a244ee9f 100644 --- a/include/fintamath/expressions/ExpressionUtils.hpp +++ b/include/fintamath/expressions/ExpressionUtils.hpp @@ -74,6 +74,8 @@ std::pair splitPowExpr(const ArgumentPtr &rhs); std::pair splitRational(const ArgumentPtr &arg); +ArgumentPtr negate(const ArgumentPtr &arg); + ArgumentPtr makePolynom(const IFunction &func, ArgumentPtrVector &&args); ArgumentPtr makePolynom(const IFunction &func, const ArgumentPtrVector &args); diff --git a/src/fintamath/expressions/ExpressionUtils.cpp b/src/fintamath/expressions/ExpressionUtils.cpp index b3530322c..ffe214ba4 100644 --- a/src/fintamath/expressions/ExpressionUtils.cpp +++ b/src/fintamath/expressions/ExpressionUtils.cpp @@ -223,6 +223,37 @@ std::pair splitRational(const ArgumentPtr &arg) { return {arg, Integer(1).clone()}; } +ArgumentPtr negate(const ArgumentPtr &arg) { + if (const auto expr = cast(arg)) { + if (is(expr->getFunction())) { + auto negChildrenView = + expr->getChildren() | + stdv::transform([](const ArgumentPtr &child) { + return negate(child); + }); + return makePolynom(Add{}, ArgumentPtrVector(negChildrenView.begin(), negChildrenView.end())); // TODO: use C++23 stdv::to + } + + if (is(expr->getFunction())) { + if (const auto firstChildNum = cast(expr->getChildren().front())) { + if (*firstChildNum == Integer(-1)) { + ArgumentPtrVector negChildren(expr->getChildren().begin() + 1, expr->getChildren().end()); + return makePolynom(Mul(), std::move(negChildren)); + } + + ArgumentPtrVector negChildren = expr->getChildren(); + negChildren.front() = (*firstChildNum) * Integer(-1); + return makePolynom(Mul(), std::move(negChildren)); + } + } + } + else if (const auto arithm = cast(arg)) { + return (*arithm) * Integer(-1); + } + + return mulExpr(Integer(-1).clone(), arg); +} + ArgumentPtr makePolynom(const IFunction &func, ArgumentPtrVector &&args) { if (args.empty()) { return {}; diff --git a/src/fintamath/expressions/binary/CompExpression.cpp b/src/fintamath/expressions/binary/CompExpression.cpp index 56d180ac7..802857234 100644 --- a/src/fintamath/expressions/binary/CompExpression.cpp +++ b/src/fintamath/expressions/binary/CompExpression.cpp @@ -47,17 +47,12 @@ std::string CompExpression::toString() const { *lhsExpr->getFunction() == Add{}) { ArgumentPtrVector sumChildren = lhsExpr->getChildren(); - const ArgumentPtr solLhs = sumChildren.front(); + ArgumentPtr solLhs = sumChildren.front(); if (is(solLhs)) { sumChildren.erase(sumChildren.begin()); - - ArgumentPtr solRhs = negExpr(std::move(sumChildren)); - simplifyChild(solRhs); - - if (!is(solRhs)) { - return CompExpression(cast(*func), solLhs, solRhs).toString(); - } + ArgumentPtr solRhs = detail::negate(makePolynom(Add{}, std::move(sumChildren))); + return CompExpression(cast(*func), std::move(solLhs), std::move(solRhs)).toString(); } } } diff --git a/src/fintamath/expressions/binary/DivExpression.cpp b/src/fintamath/expressions/binary/DivExpression.cpp index f0d211fb4..5ceeb9b0d 100644 --- a/src/fintamath/expressions/binary/DivExpression.cpp +++ b/src/fintamath/expressions/binary/DivExpression.cpp @@ -46,9 +46,8 @@ DivExpression::DivExpression(ArgumentPtr inLhsChild, ArgumentPtr inRhsChild) } std::string DivExpression::toString() const { - if (isNegated(lhsChild)) { // TODO! find more efficient solution - ArgumentPtr innerDiv = divExpr(negExpr(lhsChild)->toMinimalObject(), rhsChild); - return negExpr(std::move(innerDiv))->toString(); + if (isNegated(lhsChild)) { + return negExpr(divExpr(detail::negate(lhsChild), rhsChild))->toString(); } return IBinaryExpression::toString(); diff --git a/tests/src/expressions/ExpressionUtilsTests.cpp b/tests/src/expressions/ExpressionUtilsTests.cpp index 3c81689ec..ed30240ee 100644 --- a/tests/src/expressions/ExpressionUtilsTests.cpp +++ b/tests/src/expressions/ExpressionUtilsTests.cpp @@ -124,3 +124,7 @@ TEST(ExpressionUtilsTests, isNegativeNumberTest) { TEST(ExpressionUtilsTests, makePolynomTest) { // TODO: implement } + +TEST(ExpressionUtilsTests, negateTest) { + // TODO: implement +} diff --git a/tests/src/overall/simplify/SimplifyDerivativeTests.cpp b/tests/src/overall/simplify/SimplifyDerivativeTests.cpp index 0c6aec33e..4e5617fe7 100644 --- a/tests/src/overall/simplify/SimplifyDerivativeTests.cpp +++ b/tests/src/overall/simplify/SimplifyDerivativeTests.cpp @@ -282,7 +282,7 @@ TEST(SimplifyDerivativeTests, simplifyTest) { EXPECT_EQ(Expression("derivative(ln(cos(3x)), x)").toString(), "-3 tan(3 x)"); EXPECT_EQ(Expression("derivative(log(sin(x^5), tan(x^3)), x)").toString(), - "(3 sec(x^3)^2 x^2 cos(x^3) csc(x^3))/ln(sin(x^5)) - (5 x^4 cot(x^5) ln(tan(x^3)))/(ln(sin(x^5))^2)"); + "(3 sec(x^3)^2 x^2 cos(x^3) csc(x^3))/ln(sin(x^5)) - (5 x^4 cos(x^5) csc(x^5) ln(tan(x^3)))/(ln(sin(x^5))^2)"); EXPECT_EQ(Expression("derivative(acos(4x + 5)^5, x)").toString(), "-(20 acos(4 x + 5)^4)/sqrt(-16 x^2 - 40 x - 24)"); EXPECT_EQ(Expression("derivative(sin(sin(sin(x))), x)").toString(),