Skip to content

Commit

Permalink
Lazy computations in Expression
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Sep 25, 2023
1 parent 45af3b9 commit adfb065
Show file tree
Hide file tree
Showing 11 changed files with 172 additions and 150 deletions.
65 changes: 39 additions & 26 deletions include/fintamath/expressions/Expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,31 +54,7 @@ class Expression : public IExpressionCRTP<Expression> {
}

template <typename Function, bool isPolynomial = false>
static void registerFunctionExpressionMaker(auto &&maker) {
Parser::Function<std::unique_ptr<IMathObject>, const ArgumentPtrVector &> constructor =
[maker = std::forward<decltype(maker)>(maker)](const ArgumentPtrVector &args) {
static const IFunction::Type type = Function().getFunctionType();
std::unique_ptr<IMathObject> res;

if constexpr (IsFunctionTypeAny<Function>::value) {
res = maker(args);
}
else if constexpr (isPolynomial) {
if (size_t(type) <= args.size()) {
res = maker(args);
}
}
else {
if (size_t(type) == args.size()) {
res = maker(args);
}
}

return res;
};

Parser::add<Function>(getExpressionMakers(), std::move(constructor));
}
static void registerFunctionExpressionMaker(auto &&maker);

static MathObjectType getTypeStatic() {
return MathObjectType::Expression;
Expand All @@ -100,6 +76,10 @@ class Expression : public IExpressionCRTP<Expression> {
ArgumentPtr preciseSimplify() const override;

private:
void simplifyMutable() const;

void updateStringMutable() const;

bool parseOperator(const TermVector &terms, size_t start, size_t end);

bool parseFunction(const TermVector &terms, size_t start, size_t end);
Expand Down Expand Up @@ -142,6 +122,8 @@ class Expression : public IExpressionCRTP<Expression> {

static void preciseRec(ArgumentPtr &arg, uint8_t precision);

static ArgumentPtr compress(const ArgumentPtr &child);

friend std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentPtrVector &args);

friend ArgumentPtr parseExpr(const std::string &str);
Expand All @@ -155,7 +137,11 @@ class Expression : public IExpressionCRTP<Expression> {
static Parser::Map<std::unique_ptr<IMathObject>, const ArgumentPtrVector &> &getExpressionMakers();

private:
ArgumentPtr child;
mutable ArgumentPtr child;

mutable std::string stringCached;

mutable bool isSimplified = false;
};

ArgumentPtr parseExpr(const std::string &str);
Expand Down Expand Up @@ -188,4 +174,31 @@ Expression operator/(const Expression &lhs, const Variable &rhs);

Expression operator/(const Variable &lhs, const Expression &rhs);

template <typename Function, bool isPolynomial>
inline void Expression::registerFunctionExpressionMaker(auto &&maker) {
Parser::Function<std::unique_ptr<IMathObject>, const ArgumentPtrVector &> constructor =
[maker = std::forward<decltype(maker)>(maker)](const ArgumentPtrVector &args) {
static const IFunction::Type type = Function().getFunctionType();
std::unique_ptr<IMathObject> res;

if constexpr (IsFunctionTypeAny<Function>::value) {
res = maker(args);
}
else if constexpr (isPolynomial) {
if (size_t(type) <= args.size()) {
res = maker(args);
}
}
else {
if (size_t(type) == args.size()) {
res = maker(args);
}
}

return res;
};

Parser::add<Function>(getExpressionMakers(), std::move(constructor));
}

}
4 changes: 1 addition & 3 deletions include/fintamath/expressions/IExpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ class IExpression : public IArithmetic {

virtual ArgumentPtr postSimplify() const;

virtual ArgumentPtr preciseSimplify() const = 0;

static void compressChild(ArgumentPtr &child);
virtual ArgumentPtr preciseSimplify() const;

static void simplifyChild(ArgumentPtr &child);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,9 @@ class IBinaryExpressionCRTP : public IBinaryExpressionBaseCRTP<Derived, isMultiF
public:
explicit IBinaryExpressionCRTP(const IFunction &inFunc, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
this->func = cast<IFunction>(inFunc.clone());

this->lhsChild = lhs;
this->compressChild(this->lhsChild);

this->rhsChild = rhs;
this->compressChild(this->rhsChild);
}
};

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ class IPolynomExpressionCRTP : public IPolynomExpressionBaseCRTP<Derived, isMult

void addElement(const ArgumentPtr &element) final {
ArgumentPtr elem = element;
this->compressChild(elem);

ArgumentPtrVector elemPolynom;

Expand Down
2 changes: 0 additions & 2 deletions include/fintamath/expressions/interfaces/IUnaryExpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ class IUnaryExpressionCRTP : public IUnaryExpressionBaseCRTP<Derived, isMultiFun
public:
explicit IUnaryExpressionCRTP(const IFunction &inFunc, const ArgumentPtr &arg) {
this->func = cast<IFunction>(inFunc.clone());

this->child = arg;
this->compressChild(this->child);
}
};

Expand Down
86 changes: 56 additions & 30 deletions src/fintamath/expressions/Expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,30 @@ Expression::Expression() : child(Integer(0).clone()) {
}

Expression::Expression(const std::string &str) : child(fintamath::parseExpr(str)) {
simplifyChild(child);
}

Expression::Expression(const ArgumentPtr &obj) {
if (auto expr = cast<Expression>(obj)) {
child = expr->child;
}
else {
child = obj;
simplifyChild(child);
}
Expression::Expression(const ArgumentPtr &obj) : child(compress(obj)) {
}

Expression::Expression(const IMathObject &obj) : Expression(obj.toMinimalObject()) {
Expression::Expression(const IMathObject &obj) : Expression(obj.clone()) {
}

Expression::Expression(int64_t val) : child(Integer(val).clone()) {
}

std::string Expression::toString() const {
return child->toString();
simplifyMutable();
return stringCached;
}

Expression Expression::precise(uint8_t precision) const {
// TODO: rework so that small ints don't convert to reals
// TODO: move this transfer to the approximation function
simplifyMutable();
Expression preciseExpr(preciseSimplify());
preciseExpr.simplifyMutable();
preciseRec(preciseExpr.child, precision);
preciseExpr.updateStringMutable();
return preciseExpr;
}

Expand Down Expand Up @@ -132,7 +130,7 @@ bool Expression::parseOperator(const TermVector &terms, size_t start, size_t end
child = makeExpr(*foundOper, lhsArg, rhsArg);
}

compressChild(child);
child = compress(child);

return true;
}
Expand Down Expand Up @@ -234,31 +232,37 @@ std::shared_ptr<IFunction> Expression::getFunction() const {
}

Expression &Expression::add(const Expression &rhs) {
child = makeExpr(Add(), *child, *rhs.child)->toMinimalObject();
child = makeExpr(Add(), *child, *rhs.child);
isSimplified = false;
return *this;
}

Expression &Expression::substract(const Expression &rhs) {
child = makeExpr(Sub(), *child, *rhs.child)->toMinimalObject();
child = makeExpr(Sub(), *child, *rhs.child);
isSimplified = false;
return *this;
}

Expression &Expression::multiply(const Expression &rhs) {
child = makeExpr(Mul(), *child, *rhs.child)->toMinimalObject();
child = makeExpr(Mul(), *child, *rhs.child);
isSimplified = false;
return *this;
}

Expression &Expression::divide(const Expression &rhs) {
child = makeExpr(Div(), *child, *rhs.child)->toMinimalObject();
child = makeExpr(Div(), *child, *rhs.child);
isSimplified = false;
return *this;
}

Expression &Expression::negate() {
child = makeExpr(Neg(), *child)->toMinimalObject();
child = makeExpr(Neg(), *child);
isSimplified = false;
return *this;
}

ArgumentPtrVector Expression::getChildren() const {
simplifyMutable();
return {child};
}

Expand Down Expand Up @@ -288,6 +292,18 @@ ArgumentPtr Expression::preciseSimplify() const {
return preciseChild;
}

void Expression::simplifyMutable() const {
if (!isSimplified) {
simplifyChild(child);
isSimplified = true;
updateStringMutable();
}
}

void Expression::updateStringMutable() const {
stringCached = child->toString();
}

TermVector Expression::tokensToTerms(const TokenVector &tokens) {
if (tokens.empty()) {
return {};
Expand Down Expand Up @@ -502,14 +518,25 @@ void Expression::preciseRec(ArgumentPtr &arg, uint8_t precision) {
}
}

ArgumentPtr Expression::compress(const ArgumentPtr &child) {
if (const auto expr = cast<Expression>(child)) {
return expr->child;
}

return child;
}

std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentPtrVector &args) {
Expression::validateFunctionArgs(func, args);
auto argsView = args | std::views::transform(&Expression::compress);
ArgumentPtrVector compressedArgs(argsView.begin(), argsView.end());

if (auto expr = Parser::parse(Expression::getExpressionMakers(), func.toString(), args)) {
Expression::validateFunctionArgs(func, compressedArgs);

if (auto expr = Parser::parse(Expression::getExpressionMakers(), func.toString(), compressedArgs)) {
return expr;
}

return FunctionExpression(func, args).clone();
return FunctionExpression(func, compressedArgs).clone();
}

std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentRefVector &args) {
Expand All @@ -535,8 +562,7 @@ void Expression::validateFunctionArgs(const IFunction &func, const ArgumentPtrVe
expectedType = expectedArgTypes[i];
}

ArgumentPtr arg = args[i];
compressChild(arg);
const ArgumentPtr &arg = args[i];

if (!doesArgMatch(expectedType, arg)) {
throw InvalidInputFunctionException(func.toString(), argumentVectorToStringVector(args));
Expand Down Expand Up @@ -579,46 +605,46 @@ Expression operator+(const Variable &lhs, const Variable &rhs) {
}

Expression operator+(const Expression &lhs, const Variable &rhs) {
return Expression(addExpr(lhs.getChildren().front(), rhs.clone()));
return Expression(addExpr(lhs, rhs));
}

Expression operator+(const Variable &lhs, const Expression &rhs) {
return Expression(addExpr(lhs.clone(), rhs.getChildren().front()));
return Expression(addExpr(lhs, rhs));
}

Expression operator-(const Variable &lhs, const Variable &rhs) {
return Expression(subExpr(lhs, rhs));
}

Expression operator-(const Expression &lhs, const Variable &rhs) {
return Expression(subExpr(lhs.getChildren().front(), rhs.clone()));
return Expression(subExpr(lhs, rhs));
}

Expression operator-(const Variable &lhs, const Expression &rhs) {
return Expression(subExpr(lhs.clone(), rhs.getChildren().front()));
return Expression(subExpr(lhs, rhs));
}

Expression operator*(const Variable &lhs, const Variable &rhs) {
return Expression(mulExpr(lhs, rhs));
}

Expression operator*(const Expression &lhs, const Variable &rhs) {
return Expression(mulExpr(lhs.getChildren().front(), rhs.clone()));
return Expression(mulExpr(lhs, rhs));
}

Expression operator*(const Variable &lhs, const Expression &rhs) {
return Expression(mulExpr(lhs.clone(), rhs.getChildren().front()));
return Expression(mulExpr(lhs, rhs));
}

Expression operator/(const Variable &lhs, const Variable &rhs) {
return Expression(divExpr(lhs, rhs));
}

Expression operator/(const Expression &lhs, const Variable &rhs) {
return Expression(divExpr(lhs.getChildren().front(), rhs.clone()));
return Expression(divExpr(lhs, rhs));
}

Expression operator/(const Variable &lhs, const Expression &rhs) {
return Expression(divExpr(lhs.clone(), rhs.getChildren().front()));
return Expression(divExpr(lhs, rhs));
}
}
Loading

0 comments on commit adfb065

Please sign in to comment.