Skip to content

Commit

Permalink
add is_leaf and has_nested to expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
Wout4 committed May 17, 2024
1 parent b3b30c1 commit d79e5f4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
11 changes: 10 additions & 1 deletion cpmpy/expressions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
import numpy as np


from .utils import is_num, is_any_list, flatlist, argval, get_bounds, is_boolexpr, is_true_cst, is_false_cst
from .utils import is_num, is_any_list, flatlist, argval, get_bounds, is_boolexpr, is_true_cst, is_false_cst, is_leaf
from ..exceptions import IncompleteFunctionError, TypeError


Expand Down Expand Up @@ -135,6 +135,12 @@ def __repr__(self):
def __hash__(self):
return hash(self.__repr__())

def has_nested(self):
return not all([is_leaf(x) for x in self.args])

def is_leaf(self):
return False # default

def is_bool(self):
""" is it a Boolean (return type) Operator?
Default: yes
Expand Down Expand Up @@ -381,6 +387,9 @@ def __bool__(self):
"""Called to implement truth value testing and the built-in operation bool(), return stored value"""
return self.args[0]

def is_leaf(self):
return True


class Comparison(Expression):
"""Represents a comparison between two sub-expressions
Expand Down
4 changes: 4 additions & 0 deletions cpmpy/expressions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def argval(a):
raise e


def is_leaf(a):
return a.is_leaf() if hasattr(a, 'is_leaf') else True


def eval_comparison(str_op, lhs, rhs):
"""
Internal function: evaluates the textual `str_op` comparison operator
Expand Down
8 changes: 7 additions & 1 deletion cpmpy/expressions/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

import numpy as np
from .core import Expression, Operator
from .utils import is_num, is_int, flatlist, is_boolexpr, is_true_cst, is_false_cst, get_bounds
from .utils import is_num, is_int, flatlist, is_boolexpr, is_true_cst, is_false_cst, get_bounds, is_leaf


def BoolVar(shape=1, name=None):
Expand Down Expand Up @@ -245,6 +245,9 @@ def is_bool(self):
"""
return False

def is_leaf(self):
return True

def value(self):
""" the value obtained in the last solve call
(or 'None')
Expand Down Expand Up @@ -385,6 +388,9 @@ def is_bool(self):
"""
return False

def is_leaf(self):
return all([is_leaf(x) for x in self])

def value(self):
""" the values, for each of the stored variables, obtained in the last solve call
(or 'None')
Expand Down

0 comments on commit d79e5f4

Please sign in to comment.