Skip to content

Commit

Permalink
Remove unused values of IOperator::Priority. Refactor Expression
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Feb 17, 2024
1 parent 8e95cf0 commit 2042fa2
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 106 deletions.
27 changes: 24 additions & 3 deletions include/fintamath/expressions/Expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <stack>
#include <string>
#include <utility>
Expand All @@ -16,6 +17,7 @@
#include "fintamath/expressions/IExpression.hpp"
#include "fintamath/functions/FunctionArguments.hpp"
#include "fintamath/functions/IFunction.hpp"
#include "fintamath/functions/IOperator.hpp"
#include "fintamath/literals/Variable.hpp"

namespace fintamath {
Expand All @@ -36,7 +38,24 @@ struct Term final {
}
};

using TermVector = std::vector<std::unique_ptr<detail::Term>>;
struct FunctionTerm final {
Term term;

std::optional<IOperator::Priority> priority;

public:
FunctionTerm() = default;

FunctionTerm(Term inTerm, const std::optional<IOperator::Priority> inPriority)
: term(std::move(inTerm)),
priority(inPriority) {
}
};

using TermVector = std::vector<Term>;

using FunctionTermStack = std::stack<FunctionTerm>;

using OperandStack = std::stack<std::unique_ptr<IMathObject>>;

}
Expand Down Expand Up @@ -104,11 +123,11 @@ class Expression final : public IExpressionCRTP<Expression> {

static std::unique_ptr<IMathObject> operandsToObject(detail::OperandStack &operands);

static ArgumentPtrVector unwrapComma(const ArgumentPtr &child);
static void moveFunctionsToOperands(detail::OperandStack &operands, detail::FunctionTermStack &functions, const IOperator *nextOper);

static void insertMultiplications(detail::TermVector &terms);

static void fixOperatorTypes(const detail::TermVector &terms);
static void fixOperatorTypes(detail::TermVector &terms);

static void collapseFactorials(detail::TermVector &terms);

Expand All @@ -128,6 +147,8 @@ class Expression final : public IExpressionCRTP<Expression> {

static bool doesArgMatch(const MathObjectType &expectedType, const ArgumentPtr &arg);

static ArgumentPtrVector unwrapComma(const ArgumentPtr &child);

static ArgumentPtr compress(const ArgumentPtr &child);

friend std::unique_ptr<IMathObject> detail::makeExpr(const IFunction &func, ArgumentPtrVector args);
Expand Down
10 changes: 6 additions & 4 deletions include/fintamath/functions/IOperator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class IOperator : public IFunction {

public:
enum class Priority : uint8_t {
Highest,
Exponentiation, // e.g. a ^ b
PostfixUnary, // e.g. a!
PrefixUnary, // e.g. -a
Expand All @@ -30,7 +29,6 @@ class IOperator : public IFunction {
Implication, // e.g. a -> b
Equivalence, // e.g. a <-> b
Comma, // e.g. a , b
Lowest,
};

public:
Expand All @@ -45,9 +43,13 @@ class IOperator : public IFunction {
return getParser().parse(validator, parsedStr);
}

static std::unique_ptr<IOperator> parse(const std::string &parsedStr, Priority priority = Priority::Lowest) {
static std::unique_ptr<IOperator> parse(const std::string &parsedStr) {
return getParser().parse(parsedStr);
}

static std::unique_ptr<IOperator> parse(const std::string &parsedStr, Priority priority) {
const auto validator = [priority](const std::unique_ptr<IOperator> &oper) {
return priority == Priority::Lowest || oper->getPriority() == priority;
return oper->getPriority() == priority;
};
return getParser().parse(validator, parsedStr);
}
Expand Down
183 changes: 84 additions & 99 deletions src/fintamath/expressions/Expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <ranges>
#include <stack>
#include <string>
Expand Down Expand Up @@ -40,20 +41,6 @@ namespace fintamath {

using namespace detail;

struct TermWithPriority final {
std::unique_ptr<Term> term;

IOperator::Priority priority = IOperator::Priority::Lowest;

public:
TermWithPriority() = default;

TermWithPriority(std::unique_ptr<Term> inTerm, const IOperator::Priority inPriority)
: term(std::move(inTerm)),
priority(inPriority) {
}
};

Expression::Expression() : child(Integer(0).clone()) {
}

Expand Down Expand Up @@ -196,10 +183,10 @@ TermVector Expression::tokensToTerms(const TokenVector &tokens) {

for (const auto i : stdv::iota(0U, terms.size())) {
if (auto term = getTermParser().parse(tokens[i])) {
terms[i] = std::move(term);
terms[i] = std::move(*term);
}
else {
terms[i] = std::make_unique<Term>(tokens[i], std::unique_ptr<IMathObject>{});
terms[i] = Term(tokens[i], std::unique_ptr<IMathObject>{});
}
}

Expand All @@ -213,66 +200,49 @@ TermVector Expression::tokensToTerms(const TokenVector &tokens) {
// Use the shunting yard algorithm
// https://en.m.wikipedia.org/wiki/Shunting_yard_algorithm
OperandStack Expression::termsToOperands(TermVector &terms) {
OperandStack outStack;
std::stack<TermWithPriority> operStack;
OperandStack operands;
FunctionTermStack functions;

for (auto &term : terms) {
if (!term->value) {
if (term->name == "(") {
operStack.emplace(std::move(term), IOperator::Priority::Lowest);
if (!term.value) {
if (term.name == "(") {
functions.emplace(std::move(term), std::optional<IOperator::Priority>{});
}
else if (term->name == ")") {
while (!operStack.empty() &&
operStack.top().term->name != "(") {

outStack.emplace(std::move(operStack.top().term->value));
operStack.pop();
}
else if (term.name == ")") {
moveFunctionsToOperands(operands, functions, {});

if (operStack.empty()) {
if (functions.empty()) {
throw InvalidInputException("");
}

operStack.pop();
functions.pop();
}
else {
throw InvalidInputException("");
}
}
else {
if (is<IFunction>(term->value)) {
if (const auto *oper = cast<IOperator>(term->value.get())) {
while (!operStack.empty() &&
operStack.top().term->name != "(" &&
operStack.top().priority <= oper->getPriority() &&
!isPrefixOperator(oper)) {

outStack.emplace(std::move(operStack.top().term->value));
operStack.pop();
}

operStack.emplace(std::move(term), oper->getPriority());
}
else {
operStack.emplace(std::move(term), IOperator::Priority::Highest);
}
}
else {
outStack.emplace(std::move(term->value));
else if (is<IFunction>(term.value)) {
std::optional<IOperator::Priority> priority;

if (const auto *oper = cast<IOperator>(term.value.get())) {
moveFunctionsToOperands(operands, functions, oper);
priority = oper->getPriority();
}

functions.emplace(std::move(term), priority);
}
else {
operands.emplace(std::move(term.value));
}
}

while (!operStack.empty()) {
if (operStack.top().term->name == "(") {
throw InvalidInputException("");
}
moveFunctionsToOperands(operands, functions, {});

outStack.emplace(std::move(operStack.top().term->value));
operStack.pop();
if (!functions.empty()) {
throw InvalidInputException("");
}

return outStack;
return operands;
}

std::unique_ptr<IMathObject> Expression::operandsToObject(OperandStack &operands) {
Expand Down Expand Up @@ -308,53 +278,53 @@ std::unique_ptr<IMathObject> Expression::operandsToObject(OperandStack &operands
return arg;
}

ArgumentPtrVector Expression::unwrapComma(const ArgumentPtr &child) {
if (const auto childExpr = cast<IExpression>(child);
childExpr &&
is<Comma>(childExpr->getFunction())) {
void Expression::moveFunctionsToOperands(OperandStack &operands, std::stack<FunctionTerm> &functions, const IOperator *nextOper) {
if (isPrefixOperator(nextOper)) {
return;
}

const ArgumentPtr &lhs = childExpr->getChildren().front();
const ArgumentPtr &rhs = childExpr->getChildren().back();
while (!functions.empty() &&
functions.top().term.name != "(" &&
(!nextOper ||
!functions.top().priority ||
*functions.top().priority <= nextOper->getPriority())) {

ArgumentPtrVector children = unwrapComma(lhs);
children.push_back(rhs);
return children;
operands.emplace(std::move(functions.top().term.value));
functions.pop();
}

return {child};
}

void Expression::insertMultiplications(TermVector &terms) {
static const ArgumentPtr mul = Mul{}.clone();

for (size_t i = 1; i < terms.size(); i++) {
if (canNextTermBeBinaryOperator(*terms[i - 1]) &&
canPrevTermBeBinaryOperator(*terms[i])) {
if (canNextTermBeBinaryOperator(terms[i - 1]) &&
canPrevTermBeBinaryOperator(terms[i])) {

auto term = std::make_unique<Term>(mul->toString(), mul->clone());
Term term(mul->toString(), mul->clone());
terms.insert(terms.begin() + static_cast<ptrdiff_t>(i), std::move(term));
i++;
}
}
}

void Expression::fixOperatorTypes(const TermVector &terms) {
void Expression::fixOperatorTypes(TermVector &terms) {
bool isFixed = true;

if (const auto &term = terms.front();
is<IOperator>(term->value) &&
!isPrefixOperator(term->value.get())) {
if (auto &term = terms.front();
is<IOperator>(term.value) &&
!isPrefixOperator(term.value.get())) {

term->value = IOperator::parse(term->name, IOperator::Priority::PrefixUnary);
isFixed = static_cast<bool>(term->value);
term.value = IOperator::parse(term.name, IOperator::Priority::PrefixUnary);
isFixed = static_cast<bool>(term.value);
}

if (const auto &term = terms.back();
is<IOperator>(term->value) &&
!isPostfixOperator(term->value.get())) {
if (auto &term = terms.back();
is<IOperator>(term.value) &&
!isPostfixOperator(term.value.get())) {

term->value = IOperator::parse(term->name, IOperator::Priority::PostfixUnary);
isFixed = isFixed && static_cast<bool>(term->value);
term.value = IOperator::parse(term.name, IOperator::Priority::PostfixUnary);
isFixed = isFixed && static_cast<bool>(term.value);
}

if (!isFixed) {
Expand All @@ -366,28 +336,28 @@ void Expression::fixOperatorTypes(const TermVector &terms) {
}

for (const auto i : stdv::iota(1U, terms.size() - 1)) {
const auto &term = terms[i];
auto &term = terms[i];
const auto &termPrev = terms[i - 1];

if (is<IOperator>(term->value) &&
!isPrefixOperator(term->value.get()) &&
!canNextTermBeBinaryOperator(*termPrev)) {
if (is<IOperator>(term.value) &&
!isPrefixOperator(term.value.get()) &&
!canNextTermBeBinaryOperator(termPrev)) {

term->value = IOperator::parse(term->name, IOperator::Priority::PrefixUnary);
isFixed = isFixed && term->value;
term.value = IOperator::parse(term.name, IOperator::Priority::PrefixUnary);
isFixed = isFixed && term.value;
}
}

for (const auto i : stdv::iota(1U, terms.size() - 1) | stdv::reverse) {
const auto &term = terms[i];
auto &term = terms[i];
const auto &termNext = terms[i + 1];

if (is<IOperator>(term->value) &&
!isPostfixOperator(term->value.get()) &&
!canPrevTermBeBinaryOperator(*termNext)) {
if (is<IOperator>(term.value) &&
!isPostfixOperator(term.value.get()) &&
!canPrevTermBeBinaryOperator(termNext)) {

term->value = IOperator::parse(term->name, IOperator::Priority::PostfixUnary);
isFixed = isFixed && term->value;
term.value = IOperator::parse(term.name, IOperator::Priority::PostfixUnary);
isFixed = isFixed && term.value;
}
}

Expand All @@ -398,12 +368,12 @@ void Expression::fixOperatorTypes(const TermVector &terms) {

void Expression::collapseFactorials(TermVector &terms) {
for (size_t i = 1; i + 1 < terms.size(); i++) {
const auto &term = terms[i];
auto &term = terms[i];
const auto &termNext = terms[i + 1];

if (is<Factorial>(term->value) && is<Factorial>(termNext->value)) {
const auto &oldFactorial = cast<Factorial>(*term->value);
term->value = Factorial(oldFactorial.getOrder() + 1).clone();
if (is<Factorial>(term.value) && is<Factorial>(termNext.value)) {
const auto &oldFactorial = cast<Factorial>(*term.value);
term.value = Factorial(oldFactorial.getOrder() + 1).clone();

terms.erase(terms.begin() + static_cast<ptrdiff_t>(i) + 1);
i--;
Expand Down Expand Up @@ -445,6 +415,22 @@ bool Expression::isNonOperatorFunction(const IMathObject *val) {
return is<IFunction>(val) && !is<IOperator>(val);
}

ArgumentPtrVector Expression::unwrapComma(const ArgumentPtr &child) {
if (const auto childExpr = cast<IExpression>(child);
childExpr &&
is<Comma>(childExpr->getFunction())) {

const ArgumentPtr &lhs = childExpr->getChildren().front();
const ArgumentPtr &rhs = childExpr->getChildren().back();

ArgumentPtrVector children = unwrapComma(lhs);
children.push_back(rhs);
return children;
}

return {child};
}

ArgumentPtr Expression::compress(const ArgumentPtr &child) {
if (const auto expr = cast<Expression>(child)) {
return expr->child;
Expand Down Expand Up @@ -582,5 +568,4 @@ std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentRefVe
}

}

}

0 comments on commit 2042fa2

Please sign in to comment.