Skip to content

Commit

Permalink
New Expression validation
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Sep 25, 2023
1 parent 01f0591 commit 9916b1f
Show file tree
Hide file tree
Showing 17 changed files with 105 additions and 189 deletions.
10 changes: 2 additions & 8 deletions include/fintamath/expressions/Expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,12 @@ class Expression : public IExpressionCRTP<Expression> {

static bool isNonOperatorFunction(const ArgumentPtr &val);

static void validateChild(const ArgumentPtr &inChild);
static void validateFunctionArgs(const IFunction &func, const ArgumentPtrVector &args);

static void validateFunctionArgs(const std::shared_ptr<IFunction> &func, const ArgumentPtrVector &args);
static bool doesArgMatch(const MathObjectType &expectedType, const ArgumentPtr &arg);

static void preciseRec(ArgumentPtr &arg, uint8_t precision);

friend std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentPtrVector &args);

friend std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentRefVector &args);

friend std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentPtrVector &args);

friend std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentPtrVector &args);

friend ArgumentPtr parseExpr(const std::string &str);
Expand Down
8 changes: 0 additions & 8 deletions include/fintamath/functions/FunctionUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ class IFunction;

extern bool isExpression(const IMathObject &arg);

extern std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentPtrVector &args);

extern std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentRefVector &args);

extern std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentPtrVector &args);

extern std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentRefVector &args);
Expand All @@ -35,10 +31,6 @@ ArgumentPtr toArgumentPtr(T &arg) {
}
}

std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const std::derived_from<IMathObject> auto &...args) {
return makeExprChecked(func, ArgumentRefVector{args...});
}

std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const std::derived_from<IMathObject> auto &...args) {
return makeExpr(func, ArgumentPtrVector{args.clone()...});
}
Expand Down
4 changes: 2 additions & 2 deletions include/fintamath/functions/IFunction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class IFunction : public IMathObject {
public:
virtual IFunction::Type getFunctionType() const = 0;

virtual size_t getReturnType() const = 0;
virtual MathObjectType getReturnType() const = 0;

virtual ArgumentTypeVector getArgType() const = 0;
virtual ArgumentTypeVector getArgTypes() const = 0;

virtual bool doArgsMatch(const ArgumentRefVector &argsVect) const = 0;

Expand Down
16 changes: 8 additions & 8 deletions include/fintamath/functions/IFunctionCRTP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ class IFunctionCRTP_ : public IFunction {
return type;
}

size_t getReturnType() const final {
return size_t(Return::getTypeStatic());
MathObjectType getReturnType() const final {
return Return::getTypeStatic();
}

ArgumentTypeVector getArgType() const final {
ArgumentTypeVector getArgTypes() const final {
ArgumentTypeVector argTypes;
getArgType<0, Args...>(argTypes);
getArgTypes<0, Args...>(argTypes);
return argTypes;
}

Expand Down Expand Up @@ -63,18 +63,18 @@ class IFunctionCRTP_ : public IFunction {
return makeExpr(*this, argsVect);
}

return makeExprChecked(*this, argsVect);
return makeExpr(*this, argsVect)->toMinimalObject();
}

private:
template <size_t i, typename Head, typename... Tail>
void getArgType(ArgumentTypeVector &outArgsTypes) const {
void getArgTypes(ArgumentTypeVector &outArgsTypes) const {
outArgsTypes.emplace_back(Head::getTypeStatic());
getArgType<i + 1, Tail...>(outArgsTypes);
getArgTypes<i + 1, Tail...>(outArgsTypes);
}

template <size_t>
void getArgType(ArgumentTypeVector & /*outArgTypes*/) const {
void getArgTypes(ArgumentTypeVector & /*outArgTypes*/) const {
// The end of unpacking.
}

Expand Down
2 changes: 1 addition & 1 deletion include/fintamath/literals/constants/IConstant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace fintamath {

class IConstant : public ILiteral {
public:
virtual size_t getReturnType() const = 0;
virtual MathObjectType getReturnType() const = 0;

std::unique_ptr<IMathObject> operator()() const {
return call();
Expand Down
2 changes: 1 addition & 1 deletion include/fintamath/literals/constants/IConstantCRTP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class IConstantCRTP_ : public IConstant {
#undef I_LITERAL_CRTP

public:
size_t getReturnType() const final {
MathObjectType getReturnType() const final {
return Return::getTypeStatic();
}

Expand Down
168 changes: 74 additions & 94 deletions src/fintamath/expressions/Expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ Expression::Expression() : child(Integer(0).clone()) {
}

Expression::Expression(const std::string &str) : child(fintamath::parseExpr(str)) {
validateChild(child);
simplifyChild(child);
}

Expand All @@ -29,7 +28,6 @@ Expression::Expression(const ArgumentPtr &obj) {
}
else {
child = obj;
validateChild(child);
simplifyChild(child);
}
}
Expand Down Expand Up @@ -236,34 +234,50 @@ std::shared_ptr<IFunction> Expression::getFunction() const {
}

Expression &Expression::add(const Expression &rhs) {
child = makeExprChecked(Add(), *child, *rhs.child);
child = makeExpr(Add(), *child, *rhs.child)->toMinimalObject();
return *this;
}

Expression &Expression::substract(const Expression &rhs) {
child = makeExprChecked(Sub(), *child, *rhs.child);
child = makeExpr(Sub(), *child, *rhs.child)->toMinimalObject();
return *this;
}

Expression &Expression::multiply(const Expression &rhs) {
child = makeExprChecked(Mul(), *child, *rhs.child);
child = makeExpr(Mul(), *child, *rhs.child)->toMinimalObject();
return *this;
}

Expression &Expression::divide(const Expression &rhs) {
child = makeExprChecked(Div(), *child, *rhs.child);
child = makeExpr(Div(), *child, *rhs.child)->toMinimalObject();
return *this;
}

Expression &Expression::negate() {
child = makeExprChecked(Neg(), *child);
child = makeExpr(Neg(), *child)->toMinimalObject();
return *this;
}

ArgumentPtrVector Expression::getChildren() const {
return {child};
}

void Expression::setChildren(const ArgumentPtrVector &childVect) {
if (childVect.size() != 1) {
throw InvalidInputFunctionException("", argumentVectorToStringVector(childVect));
}

*this = Expression(childVect.front());
}

void Expression::setVariables(const std::vector<std::pair<Variable, ArgumentPtr>> &varsToVals) {
IExpression::setVariables(varsToVals);
}

void Expression::setVariable(const Variable &var, const Expression &val) {
setVariables({{var, val.child}});
}

ArgumentPtr Expression::simplify() const {
return child;
}
Expand Down Expand Up @@ -466,75 +480,6 @@ bool Expression::isNonOperatorFunction(const ArgumentPtr &val) {
return is<IFunction>(val) && !is<IOperator>(val);
}

void Expression::validateChild(const ArgumentPtr &inChild) {
const auto childExpr = cast<IExpression>(inChild);

if (!childExpr) {
return;
}

const std::shared_ptr<IFunction> func = childExpr->getFunction();
const ArgumentPtrVector children = childExpr->getChildren();

if (func->getFunctionType() == IFunction::Type::Any || children.size() <= size_t(func->getFunctionType())) {
validateFunctionArgs(func, children);
}
else {
for (auto i : std::views::iota(0U, children.size() - 1)) {
for (auto j : std::views::iota(i + 1, children.size())) {
validateFunctionArgs(func, {children[i], children[j]});
}
}
}

for (const auto &arg : children) {
validateChild(arg);
}
}

void Expression::validateFunctionArgs(const std::shared_ptr<IFunction> &func, const ArgumentPtrVector &args) {
if (func->getFunctionType() == IFunction::Type::Any && args.empty()) {
throw InvalidInputFunctionException(func->toString(), argumentVectorToStringVector(args));
}

ArgumentTypeVector childrenTypes = func->getArgType();

if (func->getFunctionType() == IFunction::Type::Any) {
childrenTypes = ArgumentTypeVector(args.size(), childrenTypes.front());
}

for (auto i : std::views::iota(0U, args.size())) {
const ArgumentPtr &arg = args[i];
const MathObjectType Type = childrenTypes[i];

if (const auto childExpr = cast<IExpression>(arg)) {
const std::shared_ptr<IFunction> childFunc = childExpr->getFunction();
const MathObjectType childType = childFunc->getReturnType();

if (childType != Variable::getTypeStatic() &&
!isBaseOf(Type, childType) &&
!isBaseOf(childType, Type)) {

throw InvalidInputFunctionException(func->toString(), argumentVectorToStringVector(args));
}
}
else if (const auto childConst = cast<IConstant>(arg)) {
const MathObjectType childType = childConst->getReturnType();

if (!isBaseOf(Type, childType) && !isBaseOf(childType, Type)) {
throw InvalidInputFunctionException(func->toString(), argumentVectorToStringVector(args));
}
}
else {
MathObjectType childType = arg->getType();

if (childType != Variable::getTypeStatic() && !isBaseOf(Type, childType)) {
throw InvalidInputFunctionException(func->toString(), argumentVectorToStringVector(args));
}
}
}
}

void Expression::preciseRec(ArgumentPtr &arg, uint8_t precision) {
if (const auto realArg = cast<Real>(arg)) {
arg = realArg->precise(precision).clone();
Expand All @@ -557,16 +502,9 @@ void Expression::preciseRec(ArgumentPtr &arg, uint8_t precision) {
}
}

std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentPtrVector &args) {
Expression res(makeExpr(func, args));
return res.getChildren().front()->clone();
}

std::unique_ptr<IMathObject> makeExprChecked(const IFunction &func, const ArgumentRefVector &args) {
return makeExprChecked(func, argumentRefVectorToArgumentPtrVector(args));
}

std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentPtrVector &args) {
Expression::validateFunctionArgs(func, args);

if (auto expr = Parser::parse(Expression::getExpressionMakers(), func.toString(), args)) {
return expr;
}
Expand All @@ -578,20 +516,62 @@ std::unique_ptr<IMathObject> makeExpr(const IFunction &func, const ArgumentRefVe
return makeExpr(func, argumentRefVectorToArgumentPtrVector(args));
}

void Expression::setChildren(const ArgumentPtrVector &childVect) {
if (childVect.size() != 1) {
throw InvalidInputFunctionException("", argumentVectorToStringVector(childVect));
void Expression::validateFunctionArgs(const IFunction &func, const ArgumentPtrVector &args) {
IFunction::Type funcType = func.getFunctionType();

if ((funcType != IFunction::Type::None && args.empty()) ||
(funcType != IFunction::Type::Any && args.size() < size_t(funcType))) {

throw InvalidInputFunctionException(func.toString(), argumentVectorToStringVector(args));
}

*this = Expression(childVect.front());
}
bool doesArgSizeMatch = funcType != IFunction::Type::Any && args.size() == size_t(funcType);

void Expression::setVariables(const std::vector<std::pair<Variable, ArgumentPtr>> &varsToVals) {
IExpression::setVariables(varsToVals);
ArgumentTypeVector expectedArgTypes = func.getArgTypes();
MathObjectType expectedType = expectedArgTypes.front();

for (auto i : std::views::iota(0U, args.size())) {
if (doesArgSizeMatch) {
expectedType = expectedArgTypes[i];
}

ArgumentPtr arg = args[i];
compressChild(arg);

if (!doesArgMatch(expectedType, arg)) {
throw InvalidInputFunctionException(func.toString(), argumentVectorToStringVector(args));
}
}
}

void Expression::setVariable(const Variable &var, const Expression &val) {
setVariables({{var, val.child}});
bool Expression::doesArgMatch(const MathObjectType &expectedType, const ArgumentPtr &arg) {
if (const auto childExpr = cast<IExpression>(arg)) {
const std::shared_ptr<IFunction> childFunc = childExpr->getFunction();
const MathObjectType childType = childFunc->getReturnType();

if (childType != Variable::getTypeStatic() &&
!isBaseOf(expectedType, childType) &&
!isBaseOf(childType, expectedType)) {

return false;
}
}
else if (const auto childConst = cast<IConstant>(arg)) {
const MathObjectType childType = childConst->getReturnType();

if (!isBaseOf(expectedType, childType) && !isBaseOf(childType, expectedType)) {
return false;
}
}
else {
MathObjectType childType = arg->getType();

if (childType != Variable::getTypeStatic() && !isBaseOf(expectedType, childType)) {
return false;
}
}

return true;
}

Expression operator+(const Variable &lhs, const Variable &rhs) {
Expand Down
2 changes: 1 addition & 1 deletion src/fintamath/expressions/binary/CompExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ ArgumentPtr CompExpression::coeffSimplify(const IFunction &func, const ArgumentP
child = divExpr(child, dividerNum);
}

ArgumentPtr newLhs = addExpr(dividendPolynom);
ArgumentPtr newLhs = dividendPolynom.size() > 1 ? addExpr(dividendPolynom) : dividendPolynom.front();

if (*dividerNum < Integer(0)) {
return makeExpr(*cast<IFunction>(getOppositeFunction(func)), newLhs, rhs);
Expand Down
10 changes: 5 additions & 5 deletions src/fintamath/expressions/binary/DivExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,7 @@ std::pair<ArgumentPtr, ArgumentPtr> DivExpression::sumSumSimplify(const Argument
return {};
}

ArgumentPtr result = addExpr(resultVect);

ArgumentPtr result = resultVect.size() > 1 ? addExpr(resultVect) : resultVect.front();
ArgumentPtr remainderAdd = addExpr(remainderVect);
ArgumentPtr remainder = divExpr(remainderAdd, rhs);
simplifyChild(remainder);
Expand Down Expand Up @@ -330,11 +329,11 @@ std::pair<ArgumentPtr, ArgumentPtr> DivExpression::sumMulSimplify(const Argument
return {};
}

ArgumentPtr result = addExpr(resultChildren);
ArgumentPtr result = resultChildren.size() > 1 ? addExpr(resultChildren) : resultChildren.front();

ArgumentPtr remainder;
if (!remainderChildren.empty()) {
ArgumentPtr remainderAdd = addExpr(remainderChildren);
ArgumentPtr remainderAdd = remainderChildren.size() > 1 ? addExpr(remainderChildren) : remainderChildren.front();
remainder = divExpr(remainderAdd, rhs);
simplifyChild(remainder);
}
Expand Down Expand Up @@ -368,7 +367,8 @@ std::pair<ArgumentPtr, ArgumentPtr> DivExpression::mulSumSimplify(const Argument
multiplicator.emplace_back(mulExpr(rhsChildren[i], result));
}

ArgumentPtr remainderNegAdd = negExpr(addExpr(multiplicator));
ArgumentPtr remainderAdd = multiplicator.size() > 1 ? addExpr(multiplicator) : multiplicator.front();
ArgumentPtr remainderNegAdd = negExpr(remainderAdd);
simplifyChild(remainderNegAdd);
ArgumentPtr remainder = divExpr(remainderNegAdd, rhs);

Expand Down
Loading

0 comments on commit 9916b1f

Please sign in to comment.