From 5421b05dd2e8353a75cdaf64f394b36a82b115aa Mon Sep 17 00:00:00 2001 From: fintarin Date: Mon, 25 Sep 2023 18:23:48 +0300 Subject: [PATCH] Use shunting yard algorithm for parsing --- include/fintamath/core/MathObjectTypes.hpp | 1 + include/fintamath/expressions/Expression.hpp | 29 +- include/fintamath/functions/IOperator.hpp | 10 +- include/fintamath/functions/other/Comma.hpp | 27 ++ src/fintamath/config/ParserConfig.cpp | 2 + src/fintamath/expressions/Expression.cpp | 396 ++++++++----------- src/fintamath/functions/other/Comma.cpp | 12 + tests/src/expressions/ExpressionTests.cpp | 9 +- tests/src/functions/other/CommaTests.cpp | 67 ++++ tests/src/functions/other/IndexTests.cpp | 3 + 10 files changed, 290 insertions(+), 266 deletions(-) create mode 100644 include/fintamath/functions/other/Comma.hpp create mode 100644 src/fintamath/functions/other/Comma.cpp create mode 100644 tests/src/functions/other/CommaTests.cpp diff --git a/include/fintamath/core/MathObjectTypes.hpp b/include/fintamath/core/MathObjectTypes.hpp index 3c04c6054..2966289cb 100644 --- a/include/fintamath/core/MathObjectTypes.hpp +++ b/include/fintamath/core/MathObjectTypes.hpp @@ -131,6 +131,7 @@ class MathObjectType { Nequiv, Index, Deg, + Comma, None = std::numeric_limits::max(), }; diff --git a/include/fintamath/expressions/Expression.hpp b/include/fintamath/expressions/Expression.hpp index bc0ffeaa6..607df66af 100644 --- a/include/fintamath/expressions/Expression.hpp +++ b/include/fintamath/expressions/Expression.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include "fintamath/core/CoreConstants.hpp" #include "fintamath/core/IArithmetic.hpp" #include "fintamath/expressions/IExpression.hpp" @@ -22,6 +24,7 @@ struct Term { }; using TermVector = std::vector>; +using OperandStack = std::stack; class Expression : public IExpressionCRTP { public: @@ -80,17 +83,13 @@ class Expression : public IExpressionCRTP { void updateStringMutable() const; - bool parseOperator(const TermVector &terms, size_t start, size_t end); - - bool parseFunction(const TermVector &terms, size_t start, size_t end); - - bool parseBrackets(const TermVector &terms, size_t start, size_t end); + static TermVector tokensToTerms(const TokenVector &tokens); - bool parseTerm(const TermVector &terms, size_t start, size_t end); + static OperandStack termsToOperands(const TermVector &terms); - static ArgumentPtrVector parseFunctionArgs(const TermVector &terms, size_t start, size_t end); + static ArgumentPtr operandsToExpr(OperandStack &operands); - static TermVector tokensToTerms(const TokenVector &tokens); + static ArgumentPtrVector unwrapComma(const ArgumentPtr &child); static void insertMultiplications(TermVector &terms); @@ -102,12 +101,6 @@ class Expression : public IExpressionCRTP { static bool canPrevTermBeBinaryOperator(const Term &term); - static bool skipBrackets(const TermVector &terms, size_t &openBracketIndex); - - static void cutBrackets(const TermVector &terms, size_t &start, size_t &end); - - static std::string termsToString(const TermVector &terms); - static bool isBinaryOperator(const ArgumentPtr &val); static bool isPrefixOperator(const ArgumentPtr &val); @@ -128,10 +121,6 @@ class Expression : public IExpressionCRTP { friend ArgumentPtr parseExpr(const std::string &str); - friend ArgumentPtr parseExpr(const TermVector &terms); - - friend ArgumentPtr parseExpr(const TermVector &terms, size_t start, size_t end); - static Parser::Vector, const Token &> &getTermMakers(); static Parser::Map, const ArgumentPtrVector &> &getExpressionMakers(); @@ -148,10 +137,6 @@ class Expression : public IExpressionCRTP { ArgumentPtr parseExpr(const std::string &str); -ArgumentPtr parseExpr(const TermVector &terms); - -ArgumentPtr parseExpr(const TermVector &terms, size_t start, size_t end); - Expression operator+(const Variable &lhs, const Variable &rhs); Expression operator+(const Expression &lhs, const Variable &rhs); diff --git a/include/fintamath/functions/IOperator.hpp b/include/fintamath/functions/IOperator.hpp index f0ff63f17..4e8af1a9f 100644 --- a/include/fintamath/functions/IOperator.hpp +++ b/include/fintamath/functions/IOperator.hpp @@ -9,6 +9,7 @@ class IOperator : public IFunction { public: enum class Priority : uint16_t { + Highest, Exponentiation, // e.g. a ^ b PostfixUnary, // e.g. a! PrefixUnary, // e.g. -a @@ -19,7 +20,8 @@ class IOperator : public IFunction { Disjunction, // e.g. a | b Implication, // e.g. a -> b Equivalence, // e.g. a <-> b - Any, + Comma, // e.g. a , b + Lowest, }; public: @@ -33,9 +35,9 @@ class IOperator : public IFunction { Parser::registerType(getParser()); } - static std::unique_ptr parse(const std::string &parsedStr, IOperator::Priority priority = IOperator::Priority::Any) { + static std::unique_ptr parse(const std::string &parsedStr, IOperator::Priority priority = IOperator::Priority::Lowest) { Parser::Comparator &> comp = [priority](const std::unique_ptr &oper) { - return priority == IOperator::Priority::Any || oper->getOperatorPriority() == priority; + return priority == IOperator::Priority::Lowest || oper->getOperatorPriority() == priority; }; return Parser::parse>(getParser(), comp, parsedStr); } @@ -55,7 +57,7 @@ class IOperatorCRTP : public IOperator { #undef I_OPERATOR_CRTP public: - explicit IOperatorCRTP(IOperator::Priority inPriority = IOperator::Priority::Any, + explicit IOperatorCRTP(IOperator::Priority inPriority = IOperator::Priority::Lowest, bool isAssociative = false, bool isNonExressionEvaluatable = true) : isNonExressionEvaluatableFunc(isNonExressionEvaluatable), diff --git a/include/fintamath/functions/other/Comma.hpp b/include/fintamath/functions/other/Comma.hpp new file mode 100644 index 000000000..5a45ebd01 --- /dev/null +++ b/include/fintamath/functions/other/Comma.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "fintamath/core/IArithmetic.hpp" +#include "fintamath/functions/IOperator.hpp" + +namespace fintamath { + +class Comma : public IOperatorCRTP { +public: + Comma() : IOperatorCRTP(IOperator::Priority::Comma, true) { + } + + std::string toString() const override { + return ","; + } + + static MathObjectType getTypeStatic() { + return MathObjectType::Comma; + } + +protected: + std::unique_ptr call(const ArgumentRefVector &argsVect) const override; +}; + +FINTAMATH_FUNCTION_EXPRESSION(Comma, commaExpr); + +} diff --git a/src/fintamath/config/ParserConfig.cpp b/src/fintamath/config/ParserConfig.cpp index 403a3a8c8..b63bd2926 100644 --- a/src/fintamath/config/ParserConfig.cpp +++ b/src/fintamath/config/ParserConfig.cpp @@ -42,6 +42,7 @@ #include "fintamath/functions/logic/Nequiv.hpp" #include "fintamath/functions/logic/Not.hpp" #include "fintamath/functions/logic/Or.hpp" +#include "fintamath/functions/other/Comma.hpp" #include "fintamath/functions/other/Deg.hpp" #include "fintamath/functions/other/Factorial.hpp" #include "fintamath/functions/other/Index.hpp" @@ -230,6 +231,7 @@ struct ParserConfig { IOperator::registerType(); IOperator::registerType(); IOperator::registerType(); + IOperator::registerType(); IExpression::registerType(); } diff --git a/src/fintamath/expressions/Expression.cpp b/src/fintamath/expressions/Expression.cpp index 6ec826f17..2a7c82e24 100644 --- a/src/fintamath/expressions/Expression.cpp +++ b/src/fintamath/expressions/Expression.cpp @@ -7,6 +7,7 @@ #include "fintamath/functions/arithmetic/Mul.hpp" #include "fintamath/functions/arithmetic/Neg.hpp" #include "fintamath/functions/arithmetic/Sub.hpp" +#include "fintamath/functions/other/Comma.hpp" #include "fintamath/functions/other/Factorial.hpp" #include "fintamath/literals/Variable.hpp" #include "fintamath/literals/constants/IConstant.hpp" @@ -15,10 +16,24 @@ namespace fintamath { +struct TermWithPriority { + std::shared_ptr term; + + IOperator::Priority priority = IOperator::Priority::Lowest; + +public: + TermWithPriority() = default; + + TermWithPriority(std::shared_ptr inTerm, IOperator::Priority inPriority) + : term(std::move(inTerm)), + priority(inPriority) { + } +}; + Expression::Expression() : child(Integer(0).clone()) { } -Expression::Expression(const std::string &str) : child(fintamath::parseExpr(str)) { +Expression::Expression(const std::string &str) : child(parseExpr(str)) { } Expression::Expression(const ArgumentPtr &obj) : child(compress(obj)) { @@ -46,187 +61,6 @@ Expression Expression::precise(uint8_t precision) const { return preciseExpr; } -ArgumentPtr parseExpr(const std::string &str) { - return parseExpr(Expression::tokensToTerms(Tokenizer::tokenize(str))); -} - -ArgumentPtr parseExpr(const TermVector &terms) { - return parseExpr(terms, 0, terms.size()); -} - -ArgumentPtr parseExpr(const TermVector &terms, size_t start, size_t end) { - if (start >= end) { - throw InvalidInputException(Expression::termsToString(terms)); - } - - Expression res; - - if (res.parseOperator(terms, start, end) || - res.parseFunction(terms, start, end) || - res.parseBrackets(terms, start, end) || - res.parseTerm(terms, start, end)) { - - return res.child; - } - - throw InvalidInputException(Expression::termsToString(terms)); -} - -bool Expression::parseOperator(const TermVector &terms, size_t start, size_t end) { - // TODO! use more efficient algorithm - - size_t foundOperPos = std::numeric_limits::max(); - IOperator::Priority foundOperPriority = IOperator::Priority::Any; - bool isPreviousTermBinaryOper = false; - - for (size_t i = start; i < end; i++) { - if (skipBrackets(terms, i)) { - isPreviousTermBinaryOper = false; - i--; - continue; - } - - if (auto oper = cast(terms[i]->value)) { - if (!isPreviousTermBinaryOper) { - IOperator::Priority priority = oper->getOperatorPriority(); - - bool newOperFound = foundOperPriority == IOperator::Priority::Any || - foundOperPriority < priority; - - newOperFound = newOperFound || (isBinaryOperator(oper) && - foundOperPriority == priority && - foundOperPos < i); - - if (newOperFound) { - foundOperPriority = priority; - foundOperPos = i; - } - - isPreviousTermBinaryOper = isBinaryOperator(oper); - } - } - else { - isPreviousTermBinaryOper = false; - } - } - - if (foundOperPos == std::numeric_limits::max()) { - return false; - } - - auto foundOper = cast(terms[foundOperPos]->value); - - if (isPrefixOperator(foundOper)) { - ArgumentPtr arg = parseExpr(terms, start + 1, end); - child = makeExpr(*foundOper, arg); - } - else if (isPostfixOperator(foundOper)) { - ArgumentPtr arg = parseExpr(terms, start, end - 1); - child = makeExpr(*foundOper, arg); - } - else { - ArgumentPtr lhsArg = parseExpr(terms, start, foundOperPos); - ArgumentPtr rhsArg = parseExpr(terms, foundOperPos + 1, end); - child = makeExpr(*foundOper, lhsArg, rhsArg); - } - - child = compress(child); - - return true; -} - -bool Expression::parseFunction(const TermVector &terms, size_t start, size_t end) { - const auto &term = terms[start]; - - if (start + 1 >= end) { - return false; - } - - if (auto termFirstValue = term->value; - !is(termFirstValue) || is(termFirstValue)) { - - return false; - } - - start++; - cutBrackets(terms, start, end); - - ArgumentPtrVector args = parseFunctionArgs(terms, start, end); - std::shared_ptr func = cast(term->value); - - if (func->getFunctionType() != IFunction::Type(args.size()) && - func->getFunctionType() != IFunction::Type::Any) { - - if (auto newFunc = IFunction::parse(term->name, IFunction::Type(args.size()))) { - func = std::move(newFunc); - } - else { - return false; - } - } - - child = makeExpr(*func, args); - - return true; -} - -ArgumentPtrVector Expression::parseFunctionArgs(const TermVector &terms, size_t start, size_t end) { - if (start >= end) { - return {}; - } - - ArgumentPtrVector funcArgs; - - for (size_t i = start; i < end; i++) { - if (terms[i]->name == "(") { - skipBrackets(terms, i); - } - - if (terms[i]->name == ",") { - if (i == 0 || i + 1 == end) { - throw InvalidInputException(termsToString(terms)); - } - - ArgumentPtr lhsArg = parseExpr(terms, start, i); - ArgumentPtrVector rhsArgs = parseFunctionArgs(terms, i + 1, end); - - funcArgs.emplace_back(lhsArg); - - for (const auto &arg : rhsArgs) { - funcArgs.emplace_back(ArgumentPtr(arg)); - } - - return funcArgs; - } - } - - funcArgs.emplace_back(parseExpr(terms, start, end)); - return funcArgs; -} - -bool Expression::parseBrackets(const TermVector &terms, size_t start, size_t end) { - if (start + 2 >= end) { - return false; - } - - if (terms[start]->name == "(" && terms[end - 1]->name == ")") { - cutBrackets(terms, start, end); - child = parseExpr(terms, start, end); - return true; - } - - return false; -} - -bool Expression::parseTerm(const TermVector &terms, size_t start, size_t end) { - if (start + 1 != end || !terms[start]->value) { - return false; - } - - child = terms[start]->value; - return true; -} - const std::shared_ptr &Expression::getFunction() const { static const std::shared_ptr func; return func; @@ -306,9 +140,22 @@ void Expression::updateStringMutable() const { stringCached = child->toString(); } +ArgumentPtr parseExpr(const std::string &str) { + try { + auto tokens = Tokenizer::tokenize(str); + auto terms = Expression::tokensToTerms(tokens); + auto stack = Expression::termsToOperands(terms); + auto expr = Expression::operandsToExpr(stack); + return expr; + } + catch (const InvalidInputException &) { + throw InvalidInputException(str); + } +} + TermVector Expression::tokensToTerms(const TokenVector &tokens) { if (tokens.empty()) { - return {}; + throw InvalidInputException(""); } TermVector terms(tokens.size()); @@ -329,6 +176,121 @@ TermVector Expression::tokensToTerms(const TokenVector &tokens) { return terms; } +// Use the shunting yard algorithm +// https://en.m.wikipedia.org/wiki/Shunting_yard_algorithm +OperandStack Expression::termsToOperands(const TermVector &terms) { + std::stack outStack; + std::stack operStack; + + for (const auto &term : terms) { + if (!term->value) { + if (term->name == "(") { + operStack.emplace(term, IOperator::Priority::Lowest); + } + else if (term->name == ")") { + while (!operStack.empty() && + operStack.top().term->name != "(") { + + outStack.emplace(operStack.top().term->value); + operStack.pop(); + } + + if (operStack.empty()) { + throw InvalidInputException(""); + } + + operStack.pop(); + } + else { + throw InvalidInputException(""); + } + } + else { + if (is(term->value)) { + if (auto oper = cast(term->value)) { + while (!operStack.empty() && + operStack.top().term->name != "(" && + operStack.top().priority <= oper->getOperatorPriority() && + !isPrefixOperator(oper)) { + + outStack.emplace(operStack.top().term->value); + operStack.pop(); + } + + operStack.emplace(term, oper->getOperatorPriority()); + } + else { + operStack.emplace(term, IOperator::Priority::Highest); + } + } + else { + outStack.emplace(term->value); + } + } + } + + while (!operStack.empty()) { + if (operStack.top().term->name == "(") { + throw InvalidInputException(""); + } + + outStack.emplace(operStack.top().term->value); + operStack.pop(); + } + + return outStack; +} + +ArgumentPtr Expression::operandsToExpr(OperandStack &operands) { + if (operands.empty()) { + throw InvalidInputException(""); + } + + ArgumentPtr arg = operands.top(); + operands.pop(); + + if (auto func = cast(arg)) { + ArgumentPtr rhsChild = operandsToExpr(operands); + + if (isBinaryOperator(func)) { + ArgumentPtr lhsChild = operandsToExpr(operands); + return makeExpr(*func, {lhsChild, rhsChild}); + } + + ArgumentPtrVector children = unwrapComma(rhsChild); + + if (func->getFunctionType() != IFunction::Type::Any && + size_t(func->getFunctionType()) != children.size()) { + + func = IFunction::parse(func->toString(), IFunction::Type(children.size())); + + if (!func) { + throw InvalidInputException(""); + } + } + + return makeExpr(*func, children); + } + + return arg; +} + +ArgumentPtrVector Expression::unwrapComma(const ArgumentPtr &child) { + if (const auto childExpr = cast(child); + childExpr && + is(childExpr->getFunction())) { + + const ArgumentPtr &lhs = childExpr->getChildren().front(); + const ArgumentPtr &rhs = childExpr->getChildren().back(); + + ArgumentPtrVector children = unwrapComma(lhs); + children.push_back(rhs); + return children; + } + + return {child}; +} + void Expression::insertMultiplications(TermVector &terms) { static const ArgumentPtr mul = Mul().clone(); @@ -346,7 +308,7 @@ void Expression::insertMultiplications(TermVector &terms) { void Expression::fixOperatorTypes(TermVector &terms) { bool isFixed = true; - if (const auto &term = terms.front(); + if (auto &term = terms.front(); is(term->value) && !isPrefixOperator(term->value)) { @@ -354,7 +316,7 @@ void Expression::fixOperatorTypes(TermVector &terms) { isFixed = isFixed && term->value; } - if (const auto &term = terms.back(); + if (auto &term = terms.back(); is(term->value) && !isPostfixOperator(term->value)) { @@ -362,12 +324,16 @@ void Expression::fixOperatorTypes(TermVector &terms) { isFixed = isFixed && term->value; } + if (!isFixed) { + throw InvalidInputException(""); + } + if (terms.size() < 3) { return; } for (auto i : std::views::iota(1U, terms.size() - 1)) { - const auto &term = terms[i]; + auto &term = terms[i]; const auto &termPrev = terms[i - 1]; if (is(term->value) && @@ -381,7 +347,7 @@ void Expression::fixOperatorTypes(TermVector &terms) { // TODO: use reverse(iota(1, terms.size() - 1)) when it is work for (size_t i = terms.size() - 2; i > 0; i--) { - const auto &term = terms[i]; + auto &term = terms[i]; const auto &termNext = terms[i + 1]; if (is(term->value) && @@ -394,13 +360,13 @@ void Expression::fixOperatorTypes(TermVector &terms) { } if (!isFixed) { - throw InvalidInputException(termsToString(terms)); + throw InvalidInputException(""); } } void Expression::collapseFactorials(TermVector &terms) { for (size_t i = 1; i + 1 < terms.size(); i++) { - const auto &term = terms[i]; + auto &term = terms[i]; const auto &termNext = terms[i + 1]; if (auto factorial = cast(term->value); @@ -432,53 +398,6 @@ bool Expression::canPrevTermBeBinaryOperator(const Term &term) { term.name == ","); } -bool Expression::skipBrackets(const TermVector &terms, size_t &openBracketIndex) { - if (openBracketIndex >= terms.size() || terms[openBracketIndex]->name != "(") { - return false; - } - - int64_t brackets = 0; - - for (auto i : std::views::iota(openBracketIndex, terms.size())) { - const auto &term = terms[i]; - - if (term->name == "(") { - brackets++; - } - else if (term->name == ")") { - brackets--; - } - - if (brackets == 0) { - openBracketIndex = i + 1; - return true; - } - } - - throw InvalidInputException(termsToString(terms)); -} - -void Expression::cutBrackets(const TermVector &terms, size_t &start, size_t &end) { - if (start + 1 >= end) { - return; - } - - if (terms[start]->name == "(" && terms[end - 1]->name == ")") { - start++; - end--; - } -} - -std::string Expression::termsToString(const TermVector &terms) { - std::string res; - - for (const auto &term : terms) { - res += term->name; - } - - return res; -} - bool Expression::isBinaryOperator(const ArgumentPtr &val) { auto oper = cast(val); return oper && oper->getFunctionType() == IFunction::Type::Binary; @@ -529,8 +448,8 @@ ArgumentPtr Expression::compress(const ArgumentPtr &child) { } std::unique_ptr makeExpr(const IFunction &func, const ArgumentPtrVector &args) { - auto argsView = args | std::views::transform(&Expression::compress); - ArgumentPtrVector compressedArgs(argsView.begin(), argsView.end()); + ArgumentPtrVector compressedArgs = args; + std::ranges::transform(compressedArgs, compressedArgs.begin(), &Expression::compress); Expression::validateFunctionArgs(func, compressedArgs); @@ -649,4 +568,5 @@ Expression operator/(const Expression &lhs, const Variable &rhs) { Expression operator/(const Variable &lhs, const Expression &rhs) { return Expression(divExpr(lhs, rhs)); } + } diff --git a/src/fintamath/functions/other/Comma.cpp b/src/fintamath/functions/other/Comma.cpp new file mode 100644 index 000000000..41192feb4 --- /dev/null +++ b/src/fintamath/functions/other/Comma.cpp @@ -0,0 +1,12 @@ +#include "fintamath/functions/other/Comma.hpp" + +namespace fintamath { + +std::unique_ptr Comma::call(const ArgumentRefVector &argsVect) const { + const auto &lhs = argsVect.front().get(); + const auto &rhs = argsVect.back().get(); + + throw InvalidInputBinaryOperatorException(toString(), lhs.toString(), rhs.toString()); +} + +} diff --git a/tests/src/expressions/ExpressionTests.cpp b/tests/src/expressions/ExpressionTests.cpp index 82e20434d..ff05babb0 100644 --- a/tests/src/expressions/ExpressionTests.cpp +++ b/tests/src/expressions/ExpressionTests.cpp @@ -1441,11 +1441,12 @@ TEST(ExpressionTests, stringConstructorNegativeTest) { EXPECT_THROW(Expression("1-"), InvalidInputException); EXPECT_THROW(Expression("1*"), InvalidInputException); EXPECT_THROW(Expression("1/"), InvalidInputException); + EXPECT_THROW(Expression("*1"), InvalidInputException); + EXPECT_THROW(Expression("/1"), InvalidInputException); EXPECT_THROW(Expression(" + "), InvalidInputException); EXPECT_THROW(Expression("(1+2))"), InvalidInputException); EXPECT_THROW(Expression("5-*3"), InvalidInputException); EXPECT_THROW(Expression("5 3 +"), InvalidInputException); - EXPECT_THROW(Expression("((()()))"), InvalidInputException); EXPECT_THROW(Expression("2.2.2"), InvalidInputException); EXPECT_THROW(Expression("--"), InvalidInputException); EXPECT_THROW(Expression("."), InvalidInputException); @@ -1485,9 +1486,11 @@ TEST(ExpressionTests, stringConstructorNegativeTest) { EXPECT_THROW(Expression("(2"), InvalidInputException); EXPECT_THROW(Expression("((2)"), InvalidInputException); EXPECT_THROW(Expression("((2"), InvalidInputException); + EXPECT_THROW(Expression("((((2)((2))))"), InvalidInputException); EXPECT_THROW(Expression("(()())"), InvalidInputException); + EXPECT_THROW(Expression("((()()))"), InvalidInputException); EXPECT_THROW(Expression("((((()))))"), InvalidInputException); - EXPECT_THROW(Expression("((((2)((2))))"), InvalidInputException); + EXPECT_THROW(Expression("(,) + (,)"), InvalidInputException); EXPECT_THROW(Expression("!2"), InvalidInputException); EXPECT_THROW(Expression("!!2"), InvalidInputException); @@ -1502,6 +1505,8 @@ TEST(ExpressionTests, stringConstructorNegativeTest) { EXPECT_THROW(Expression("(a+b)*()"), InvalidInputException); EXPECT_THROW(Expression("sin(2,3)"), InvalidInputException); + EXPECT_THROW(Expression("sin(2,3) + 2"), InvalidInputException); + EXPECT_THROW(Expression("cos(sin(2,3))"), InvalidInputException); EXPECT_THROW(Expression("sin(,)"), InvalidInputException); EXPECT_THROW(Expression("sin(,2)"), InvalidInputException); EXPECT_THROW(Expression("sin(2,)"), InvalidInputException); diff --git a/tests/src/functions/other/CommaTests.cpp b/tests/src/functions/other/CommaTests.cpp new file mode 100644 index 000000000..24b845fa6 --- /dev/null +++ b/tests/src/functions/other/CommaTests.cpp @@ -0,0 +1,67 @@ +#include "gtest/gtest.h" + +#include "fintamath/functions/other/Comma.hpp" + +#include "fintamath/expressions/Expression.hpp" +#include "fintamath/functions/arithmetic/Sub.hpp" +#include "fintamath/functions/arithmetic/UnaryPlus.hpp" +#include "fintamath/literals/Variable.hpp" +#include "fintamath/numbers/Integer.hpp" + +using namespace fintamath; + +const Comma f; + +TEST(CommaTests, toStringTest) { + EXPECT_EQ(f.toString(), ","); +} + +TEST(CommaTests, getFunctionTypeTest) { + EXPECT_EQ(f.getFunctionType(), IFunction::Type::Binary); +} + +TEST(CommaTests, getOperatorPriorityTest) { + EXPECT_EQ(f.getOperatorPriority(), IOperator::Priority::Comma); +} + +TEST(CommaTests, isAssociativeTest) { + EXPECT_TRUE(f.isAssociative()); +} + +TEST(CommaTests, callTest) { + EXPECT_THROW(f(Variable("a"), Variable("a"))->toString(), InvalidInputException); + + EXPECT_THROW(f(), InvalidInputFunctionException); + EXPECT_THROW(f(Integer(1), Integer(1), Integer(1)), InvalidInputFunctionException); +} + +TEST(CommaTests, exprTest) { + EXPECT_EQ(commaExpr(Variable("a"), Integer(1))->toString(), "a , 1"); +} + +TEST(CommaTests, doArgsMatchTest) { + Variable a("a"); + Integer b(1); + + EXPECT_FALSE(f.doArgsMatch({})); + EXPECT_FALSE(f.doArgsMatch({a})); + EXPECT_TRUE(f.doArgsMatch({a, b})); + EXPECT_FALSE(f.doArgsMatch({a, b, b})); +} + +TEST(CommaTests, equalsTest) { + EXPECT_EQ(f, f); + EXPECT_EQ(f, Comma()); + EXPECT_EQ(Comma(), f); + EXPECT_EQ(f, cast(Comma())); + EXPECT_EQ(cast(Comma()), f); + EXPECT_NE(f, Sub()); + EXPECT_NE(Sub(), f); + EXPECT_NE(f, UnaryPlus()); + EXPECT_NE(UnaryPlus(), f); +} + +TEST(CommaTests, getTypeTest) { + EXPECT_EQ(Comma::getTypeStatic(), MathObjectType::Comma); + EXPECT_EQ(Comma().getType(), MathObjectType::Comma); +} diff --git a/tests/src/functions/other/IndexTests.cpp b/tests/src/functions/other/IndexTests.cpp index 56cac34d6..b99840326 100644 --- a/tests/src/functions/other/IndexTests.cpp +++ b/tests/src/functions/other/IndexTests.cpp @@ -42,6 +42,9 @@ TEST(IndexTests, callTest) { EXPECT_THROW(f(Expression("a+1"), Integer(2))->toString(), InvalidInputException); EXPECT_THROW(f(Expression("a+1"), Expression("a+1"))->toString(), InvalidInputException); EXPECT_THROW(f(Expression("a"), Expression("a>1"))->toString(), InvalidInputException); + + EXPECT_THROW(f(), InvalidInputFunctionException); + EXPECT_THROW(f(Integer(1), Integer(1), Integer(1)), InvalidInputFunctionException); } TEST(IndexTests, exprTest) {