diff --git a/cpmpy/solvers/exact.py b/cpmpy/solvers/exact.py index 5a290a6de..01c54287b 100644 --- a/cpmpy/solvers/exact.py +++ b/cpmpy/solvers/exact.py @@ -37,6 +37,8 @@ from ..transformations.normalize import toplevel_list from ..expressions.globalconstraints import DirectConstraint from ..exceptions import NotSupportedError +from ..expressions.utils import flatlist + import numpy as np import numbers @@ -589,6 +591,10 @@ def solution_hint(self, cpm_vars, vals): :param cpm_vars: list of CPMpy variables :param vals: list of (corresponding) values for the variables """ + + cpm_vars = flatlist(cpm_vars) + vals = flatlist(vals) + assert (len(cpm_vars) == len(vals)), "Variables and values must have the same size for hinting" try: pkg_resources.require("exact>=1.1.5") self.xct_solver.setSolutionHints(self.solver_vars(cpm_vars), vals) diff --git a/cpmpy/solvers/ortools.py b/cpmpy/solvers/ortools.py index 8878d9d30..d4879761e 100644 --- a/cpmpy/solvers/ortools.py +++ b/cpmpy/solvers/ortools.py @@ -32,7 +32,7 @@ from ..expressions.globalconstraints import DirectConstraint from ..expressions.variables import _NumVarImpl, _IntVarImpl, _BoolVarImpl, NegBoolView, boolvar from ..expressions.globalconstraints import GlobalConstraint -from ..expressions.utils import is_num, is_any_list, eval_comparison +from ..expressions.utils import is_num, is_any_list, eval_comparison, flatlist from ..transformations.decompose_global import decompose_in_tree from ..transformations.get_variables import get_variables from ..transformations.flatten_model import flatten_constraint, flatten_objective @@ -521,6 +521,10 @@ def solution_hint(self, cpm_vars, vals): :param vals: list of (corresponding) values for the variables """ self.ort_model.ClearHints() # because add just appends + + cpm_vars = flatlist(cpm_vars) + vals = flatlist(vals) + assert (len(cpm_vars) == len(vals)), "Variables and values must have the same size for hinting" for (cpm_var, val) in zip(cpm_vars, vals): self.ort_model.AddHint(self.solver_var(cpm_var), val) diff --git a/cpmpy/solvers/pysat.py b/cpmpy/solvers/pysat.py index 022c6056d..66003b7e1 100644 --- a/cpmpy/solvers/pysat.py +++ b/cpmpy/solvers/pysat.py @@ -34,7 +34,7 @@ from ..expressions.core import Expression, Comparison, Operator, BoolVal from ..expressions.variables import _BoolVarImpl, NegBoolView, boolvar from ..expressions.globalconstraints import DirectConstraint -from ..expressions.utils import is_any_list, is_int +from ..expressions.utils import is_int, flatlist from ..transformations.decompose_global import decompose_in_tree from ..transformations.get_variables import get_variables from ..transformations.flatten_model import flatten_constraint @@ -316,6 +316,11 @@ def solution_hint(self, cpm_vars, vals): :param cpm_vars: list of CPMpy variables :param vals: list of (corresponding) values for the variables """ + + cpm_vars = flatlist(cpm_vars) + vals = flatlist(vals) + assert (len(cpm_vars) == len(vals)), "Variables and values must have the same size for hinting" + literals = [] for (cpm_var, val) in zip(cpm_vars, vals): lit = self.solver_var(cpm_var) diff --git a/tests/test_solvers_solhint.py b/tests/test_solvers_solhint.py index 9d7c24324..00dbeac61 100644 --- a/tests/test_solvers_solhint.py +++ b/tests/test_solvers_solhint.py @@ -32,6 +32,9 @@ def test_hints(self): slv.solution_hint([a,b], [True,True]) self.assertTrue(slv.solve(**args)) # should also work with an UNSAT hint + slv.solution_hint([a,[b]], [[[False]], True]) # check nested lists + self.assertTrue(slv.solve(**args)) + except NotSupportedError: continue