Skip to content

Commit

Permalink
remove of only_bv_implies + add indomain + remove unneeded stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
Dimosts committed Sep 15, 2023
1 parent 74f744b commit ecf379b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 29 deletions.
41 changes: 13 additions & 28 deletions cpmpy/solvers/choco.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,6 @@ def solve(self, time_limit=None, **kwargs):
- time_limit: maximum solve time in seconds (float, optional)
- kwargs: any keyword argument, sets parameters of solver object
Additional keyword arguments:
The ortools solver parameters are defined in its 'sat_parameters.proto' description:
https://github.com/google/or-tools/blob/stable/ortools/sat/sat_parameters.proto
Arguments that correspond to solver parameters:
<Please document key solver arguments that the user might wish to change
for example: log_output=True, var_ordering=3, num_cores=8, ...>
<Add link to documentation of all solver parameters>
"""

if time_limit is not None:
Expand Down Expand Up @@ -162,17 +154,15 @@ def solve(self, time_limit=None, **kwargs):

return has_sol

def solveAll(self, display=None, time_limit=None, solution_limit=None, call_from_model=False, **kwargs):
def solveAll(self, display=None, time_limit=None, solution_limit=None, **kwargs):
"""
A shorthand to (efficiently) compute all solutions, map them to CPMpy and optionally display the solutions.
It is just a wrapper around the use of `OrtSolutionPrinter()` in fact.
Compute all (optimal) solutions, map them to CPMpy and optionally display the solutions.
Arguments:
- display: either a list of CPMpy expressions, OR a callback function, called with the variables after value-mapping
default/None: nothing displayed
- solution_limit: stop after this many solutions (default: None)
- call_from_model: whether the method is called from a CPMpy Model instance or not
- time_limit: maximum solve time in seconds (float, default: None)
Returns: number of solutions found
"""
Expand All @@ -194,10 +184,6 @@ def solveAll(self, display=None, time_limit=None, solution_limit=None, call_from
self.cpm_status = SolverStatus(self.name)
self.cpm_status.runtime = end - start

# cb = OrtSolutionPrinter(self, display=display, solution_limit=solution_limit)
# self.solve(enumerate_all_solutions=True, solution_callback=cb, time_limit=time_limit, **kwargs)
# return cb.solution_count()

# display if needed
if display is not None:
for sol in sols:
Expand Down Expand Up @@ -277,7 +263,7 @@ def _make_numexpr(self, cpm_expr):
Used especially to post an expression as objective function
Accepted by ORTools:
Accepted by Choco:
- Decision variable: Var
- Linear: sum([Var]) (CPMpy class 'Operator', name 'sum')
wsum([Const],[Var]) (CPMpy class 'Operator', name 'wsum')
Expand Down Expand Up @@ -326,14 +312,12 @@ def transform(self, cpm_expr):

cpm_cons = toplevel_list(cpm_expr)
supported = {"min", "max", "abs", "count", "element", "alldifferent", "alldifferent_except0", "allequal",
"table", "cumulative", "circuit", "gcc", "inverse"}
"table", "InDomain", "cumulative", "circuit", "gcc", "inverse"}
cpm_cons = decompose_in_tree(cpm_cons, supported)
cpm_cons = flatten_constraint(cpm_cons) # flat normal form
cpm_cons = canonical_comparison(cpm_cons)
cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum'])) # constraints that support reification
cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"])) # support >, <, !=
cpm_cons = only_bv_implies(cpm_cons) # everything that can create
# reified expr must go before this

return cpm_cons

Expand Down Expand Up @@ -375,15 +359,10 @@ def _post_constraint(self, cpm_expr):
What 'supported' means depends on the solver capabilities, and in effect on what transformations
are applied in `transform()`.
Returns the posted ortools 'Constraint', so that it can be used in reification
e.g. self._post_constraint(smth, reifiable=True).onlyEnforceIf(self.solver_var(bvar))
:param cpm_expr: CPMpy expression
:type cpm_expr: Expression
:param reifiable: if True, will throw an error if cpm_expr can not be reified by ortools (for safety)
"""
import pychoco as chc

# Operators: base (bool), lhs=numexpr, lhs|rhs=boolexpr (reified ->)
if isinstance(cpm_expr, Operator):
Expand Down Expand Up @@ -450,7 +429,7 @@ def _post_constraint(self, cpm_expr):
return self.chc_model.element(chcrhs, self.solver_vars(lhs.args[0]),
self.solver_var(lhs.args[1])).post()
elif lhs.name == 'mod':
# catch tricky-to-find ortools limitation
# catch tricky-to-find ortools limitation #TODO: check
divisor = lhs.args[1]
if not is_num(divisor):
if divisor.lb <= 0 and divisor.ub >= 0:
Expand All @@ -477,6 +456,10 @@ def _post_constraint(self, cpm_expr):
assert (len(cpm_expr.args) == 2) # args = [array, table]
array, table = self.solver_vars(cpm_expr.args)
return self.chc_model.table(array, table).post()
elif cpm_expr.name == 'InDomain':
assert (len(cpm_expr.args) == 2) # args = [array, table]
expr, table = self.solver_vars(cpm_expr.args)
return self.chc_model.member(expr, table).post()
elif cpm_expr.name == "cumulative":
start, dur, end, demand, cap = self.solver_vars(cpm_expr.args)
# Everything given to cumulative in Choco needs to be a variable.
Expand Down Expand Up @@ -525,7 +508,9 @@ def _post_constraint(self, cpm_expr):

# a direct constraint, pass to solver
elif isinstance(cpm_expr, DirectConstraint):
return cpm_expr.callSolver(self, self.chc_model)
c = cpm_expr.callSolver(self, self.chc_model)
print(c)
return c.post()

# else
raise NotImplementedError(cpm_expr) # if you reach this... please report on github
15 changes: 14 additions & 1 deletion tests/test_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from cpmpy import *
from cpmpy.solvers import CPM_gurobi, CPM_pysat, CPM_minizinc, CPM_pysdd, CPM_z3, CPM_exact
from cpmpy.solvers import CPM_gurobi, CPM_pysat, CPM_minizinc, CPM_pysdd, CPM_z3, CPM_exact, CPM_choco


class TestDirectORTools(unittest.TestCase):
Expand Down Expand Up @@ -127,4 +127,17 @@ def test_direct_poly(self):

self.assertEqual(y.value(), poly_val)

@pytest.mark.skipif(not CPM_choco.supported(),
reason="pychoco not installed")
class TestDirectChoco(unittest.TestCase):

def test_direct_global(self):
iv = intvar(1,9, shape=3)

model = SolverLookup.get("choco")

model += DirectConstraint("increasing", iv)
model += iv[1] < iv[0]

self.assertFalse(model.solve())

0 comments on commit ecf379b

Please sign in to comment.