Skip to content

Commit

Permalink
reification + implications + correct transformations + cover cases fo…
Browse files Browse the repository at this point in the history
…r choco sytax
  • Loading branch information
Dimosts committed Sep 18, 2023
1 parent 15d5cd2 commit e7c29f5
Showing 1 changed file with 50 additions and 22 deletions.
72 changes: 50 additions & 22 deletions cpmpy/solvers/choco.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ..transformations.flatten_model import flatten_constraint, flatten_objective
from ..transformations.comparison import only_numexpr_equality
from ..transformations.linearize import canonical_comparison
from ..transformations.reification import only_bv_reifies, only_bv_implies, reify_rewrite


class CPM_choco(SolverInterface):
Expand Down Expand Up @@ -308,10 +309,15 @@ def transform(self, cpm_expr):
cpm_cons = toplevel_list(cpm_expr)
supported = {"min", "max", "abs", "count", "element", "alldifferent", "alldifferent_except0", "allequal",
"table", "InDomain", "cumulative", "circuit", "gcc", "inverse"}
cpm_cons = decompose_in_tree(cpm_cons, supported)
cpm_cons = flatten_constraint(cpm_cons) # flat normal form
supported_reified = {"alldifferent", "alldifferent_except0", "allequal",
"table", "InDomain", "cumulative", "circuit", "gcc", "inverse"}
cpm_cons = decompose_in_tree(cpm_cons, supported, supported_reified)
cpm_cons = canonical_comparison(cpm_cons)
cpm_cons = flatten_constraint(cpm_cons) # flat normal form
cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(["sum", "wsum", "alldifferent", "alldifferent_except0", "allequal",
"table", "InDomain", "cumulative", "circuit", "gcc", "inverse"])) # constraints that support reification
cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"])) # support >, <, !=
cpm_cons = only_bv_reifies(cpm_cons)

return cpm_cons

Expand All @@ -338,7 +344,9 @@ def __add__(self, cpm_expr):

# transform and post the constraints
for con in self.transform(cpm_expr):
self._get_constraint(con).post()
c = self._get_constraint(con)
if c is not None:
c.post()

return self

Expand All @@ -361,17 +369,21 @@ def _get_constraint(self, cpm_expr):
elif cpm_expr.name == 'or':
return self.chc_model.or_(self.solver_vars(cpm_expr.args))
elif cpm_expr.name == '->':
assert (isinstance(cpm_expr.args[0], _BoolVarImpl)) # lhs must be boolvar
lhs = self.solver_var(cpm_expr.args[0])
lhs = self.solver_var(cpm_expr.args[0]) # should always be boolvar due to only_bv_reifies
# right hand side
if isinstance(cpm_expr.args[1], _BoolVarImpl):
# bv -> bv
# PyChoco does not have "implies" constraint
return self.chc_model.arithm(lhs, "<=", self.solver_var(cpm_expr.args[1]))
return self.chc_model.or_([~lhs, self.solver_var(cpm_expr.args[1])])
else:
# bv -> boolexpr
# the `reify_rewrite()` transformation ensures that only reifiable rhs remain here
bv = self._get_constraint(cpm_expr.args[1]).reify()
return self.chc_model.arithm(lhs, "<=", bv)
if cpm_expr.args[1].name == 'not':
bv = self._get_constraint(cpm_expr.args[1].args[0]).reify()
return self.chc_model.or_([~lhs, ~bv])
else:
bv = self._get_constraint(cpm_expr.args[1]).reify()
return self.chc_model.or_([~lhs, bv])
else:
raise NotImplementedError("Not a known supported Choco Operator '{}' {}".format(
cpm_expr.name, cpm_expr))
Expand All @@ -381,20 +393,38 @@ def _get_constraint(self, cpm_expr):
elif isinstance(cpm_expr, Comparison):
lhs = cpm_expr.args[0]
rhs = cpm_expr.args[1]
chcrhs = self.solver_var(cpm_expr.args[1])

if isinstance(lhs, _NumVarImpl) or isinstance(lhs, Operator) and (
lhs.name == 'sum' or lhs.name == 'wsum' or lhs.name == "sub"):
if lhs.is_bool() and rhs.is_bool(): #boolean equality -- Reification
if isinstance(rhs, _NumVarImpl):
return self.chc_model.all_equal(self.solver_vars([lhs, rhs]))
else:
bv = self._get_constraint(rhs).reify()
return self.chc_model.all_equal([self.solver_var(lhs), bv])
elif isinstance(lhs, _NumVarImpl) or (isinstance(lhs, Operator) and (
lhs.name == 'sum' or lhs.name == 'wsum' or lhs.name == "sub")):
# a BoundedLinearExpression LHS, special case, like in objective
chc_numexpr = self._make_numexpr(cpm_expr)
return chc_numexpr
elif cpm_expr.name == '==':
chcrhs = self.solver_var(rhs)
# NumExpr == IV, supported by Choco (thanks to `only_numexpr_equality()` transformation)
if lhs.name == 'min':
if isinstance(rhs, int): # Choco does not accept an int in rhs
chcrhs = self.chc_model.intvar(rhs, rhs) # convert to "variable"
elif not isinstance(rhs, _NumVarImpl):
raise Exception(f"Choco cannot accept min operation equal to: {rhs}")
return self.chc_model.min(chcrhs, self.solver_vars(lhs.args))
elif lhs.name == 'max':
if isinstance(rhs, int): # Choco does not accept an int in rhs
chcrhs = self.chc_model.intvar(rhs, rhs) # convert to "variable"
elif not isinstance(rhs, _NumVarImpl):
raise Exception(f"Choco cannot accept max operation equal to: {rhs}")
return self.chc_model.max(chcrhs, self.solver_vars(lhs.args))
elif lhs.name == 'abs':
if isinstance(rhs, int): # Choco does not accept an int in rhs
chcrhs = self.chc_model.intvar(rhs, rhs) # convert to "variable"
elif not isinstance(rhs, _NumVarImpl):
raise Exception(f"Choco cannot accept absolute operation equal to: {rhs}")
return self.chc_model.absolute(chcrhs, self.solver_var(lhs.args[0]))
elif lhs.name == 'count':
arr, val = self.solver_vars(lhs)
Expand All @@ -414,15 +444,15 @@ def _get_constraint(self, cpm_expr):
if isinstance(rhs, int):
result = self.chc_model.intvar(rhs, rhs) # convert to "variable"
elif isinstance(rhs, _NumVarImpl):
result = chcrhs # use variable
result = self.solver_var(rhs) # use variable
else:
raise Exception(f"Cannot accept division with the result being: {rhs}")
return self.chc_model.div(self.solver_var(lhs.args[0]), divisor, result)
elif lhs.name == 'element':
if isinstance(rhs, int):
result = self.chc_model.intvar(rhs, rhs) # convert to "variable"
elif isinstance(rhs, _NumVarImpl):
result = chcrhs # use variable
result = self.solver_var(rhs) # use variable
else:
raise Exception(f"Cannot accept the right hand side of the element constraint being: {rhs}")
return self.chc_model.element(result, self.solver_vars(lhs.args[0]),
Expand All @@ -435,6 +465,10 @@ def _get_constraint(self, cpm_expr):
rhs = self.solver_vars(rhs) # get choco variable
return self.chc_model.mod(self.solver_var(lhs.args[0]), self.solver_var(divisor), rhs)
elif lhs.name == 'pow':
if isinstance(rhs, int): # Choco does not accept an int in rhs
chcrhs = self.chc_model.intvar(rhs, rhs) # convert to "variable"
elif not isinstance(rhs, _NumVarImpl):
raise Exception(f"Choco cannot accept power operation equal to: {rhs}")
return self.chc_model.pow(self.solver_vars(lhs.args[0]), self.solver_vars(lhs.args[1]),
chcrhs)
raise NotImplementedError(
Expand All @@ -448,29 +482,23 @@ def _get_constraint(self, cpm_expr):
for i in range(len(vars)):
if isinstance(vars[i], int):
vars[i] = self.chc_model.intvar(vars[i], vars[i]) # convert to "variable"
elif isinstance(vars[i], IntVar):
vars[i] = vars[i] # use variable
else:
elif not isinstance(vars[i], IntVar):
raise Exception(f"Choco cannot accept alldifferent with: {vars[i]}")
return self.chc_model.all_different(vars)
elif cpm_expr.name == 'alldifferent_except0':
vars = self.solver_vars(cpm_expr.args)
for i in range(len(vars)):
if isinstance(vars[i], int):
vars[i] = self.chc_model.intvar(vars[i], vars[i]) # convert to "variable"
elif isinstance(vars[i], IntVar):
vars[i] = vars[i] # use variable
else:
elif not isinstance(vars[i], IntVar):
raise Exception(f"Choco cannot accept alldifferent_except0 with: {vars[i]}")
return self.chc_model.all_different_except_0(vars)
elif cpm_expr.name == 'allequal':
vars = self.solver_vars(cpm_expr.args)
for i in range(len(vars)):
if isinstance(vars[i], int):
vars[i] = self.chc_model.intvar(vars[i], vars[i]) # convert to "variable"
elif isinstance(vars[i], IntVar):
vars[i] = vars[i] # use variable
else:
elif not isinstance(vars[i], IntVar):
raise Exception(f"Choco cannot accept allequal with: {vars[i]}")
return self.chc_model.all_equal(vars)
elif cpm_expr.name == 'table':
Expand Down

0 comments on commit e7c29f5

Please sign in to comment.