Skip to content

Commit

Permalink
solve and most in post_constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
Dimosts committed Sep 13, 2023
1 parent c7a7c5a commit 6f3ccbf
Showing 1 changed file with 31 additions and 33 deletions.
64 changes: 31 additions & 33 deletions cpmpy/solvers/choco.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
"""
import sys # for stdout checking
import numpy as np
from cpmpy.transformations.normalize import toplevel_list
import time

from ..transformations.normalize import toplevel_list
from .solver_interface import SolverInterface, SolverStatus, ExitStatus
from ..exceptions import NotSupportedError
from ..expressions.core import Expression, Comparison, Operator, BoolVal
Expand Down Expand Up @@ -114,11 +115,13 @@ def solve(self, time_limit=None, **kwargs):

# call the solver, with parameters
self.chc_solver = self.chc_model.get_solver()
start = time.time()
self.chc_status = self.chc_solver.solve()
end = time.time()

# new status, translate runtime
self.cpm_status = SolverStatus(self.name)
#self.cpm_status.runtime = self.chc_status.WallTime()
self.cpm_status.runtime = end - start

"""
# translate exit status
Expand All @@ -135,23 +138,21 @@ def solve(self, time_limit=None, **kwargs):
self.cpm_status.exitstatus = ExitStatus.UNKNOWN
else: # another?
raise NotImplementedError(self.ort_status) # a new status type was introduced, please report on github
"""

# True/False depending on self.cpm_status
has_sol = self._solve_return(self.cpm_status)
has_sol = self.chc_status

# translate solution values (of user specified variables only)
self.objective_value_ = None
if has_sol:
# fill in variable values
for cpm_var in self.user_vars:
cpm_var._value = self.ort_solver.Value(self.solver_var(cpm_var))
if isinstance(cpm_var, _BoolVarImpl):
cpm_var._value = bool(cpm_var._value) # ort value is always an int
cpm_var._value = self.solver_var(cpm_var).get_value()

# translate objective
if self.has_objective():
self.objective_value_ = self.ort_solver.ObjectiveValue()
"""
# if self.has_objective():
# self.objective_value_ = self.ort_solver.ObjectiveValue()
return True

def solveAll(self, display=None, time_limit=None, solution_limit=None, call_from_model=False, **kwargs):
Expand Down Expand Up @@ -246,7 +247,7 @@ def _make_numexpr(self, cpm_expr):
lhs = cpm_expr.args[0]
rhs = cpm_expr.args[1]
op = cpm_expr.name
if op == "==": op = "="
if op == "==": op = "=" # choco uses "=" for equality

if is_num(lhs): #TODO can this happen to be num in lhs?? I think no
return cpm_expr
Expand All @@ -255,7 +256,6 @@ def _make_numexpr(self, cpm_expr):
if isinstance(lhs, _NumVarImpl): # _BoolVarImpl is subclass of _NumVarImpl
return self.chc_model.arithm(self.solver_var(lhs), op, self.solver_var(rhs))

print("cpm_expr: ", cpm_expr)
# sum or weighted sum
if isinstance(lhs, Operator):
if lhs.name == 'sum':
Expand Down Expand Up @@ -292,7 +292,7 @@ def transform(self, cpm_expr):
cpm_cons = flatten_constraint(cpm_cons) # flat normal form
cpm_cons = canonical_comparison(cpm_cons)
cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum'])) # constraints that support reification
cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"])) # supports >, <, !=
cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"])) # support >, <, != TODO: Maybe not needed
cpm_cons = only_bv_implies(cpm_cons) # everything that can create
# reified expr must go before this

Expand Down Expand Up @@ -358,7 +358,10 @@ def _post_constraint(self, cpm_expr):
lhs = self.solver_var(cpm_expr.args[0])
if isinstance(cpm_expr.args[1], _BoolVarImpl):
# bv -> bv
return lhs.implies(self.solver_var(cpm_expr.args[1]))
print(lhs)
print(cpm_expr.args[1])
return lhs.__ge__(self.solver_var(cpm_expr.args[1]))
return lhs.imp(self.solver_var(cpm_expr.args[1])).post()
else:
raise NotImplementedError("Not a known supported Choco Operator '{}' {}".format(
cpm_expr.name, cpm_expr))
Expand All @@ -370,45 +373,40 @@ def _post_constraint(self, cpm_expr):
lhs = cpm_expr.args[0]
chcrhs = self.solver_var(cpm_expr.args[1])

if isinstance(lhs, _NumVarImpl):
# both are variables, do python comparison over ORT variables
print("Arithm")
return self.chc_model.arithm(self.solver_var(lhs), cpm_expr.name, chcrhs).post()
elif isinstance(lhs, Operator) and (lhs.name == 'sum' or lhs.name == 'wsum' or lhs.name == "sub"):
if isinstance(lhs, _NumVarImpl) or isinstance(lhs, Operator) and (lhs.name == 'sum' or lhs.name == 'wsum' or lhs.name == "sub"):
# a BoundedLinearExpression LHS, special case, like in objective
chc_numexpr = self._make_numexpr(cpm_expr)
return chc_numexpr.post()
elif cpm_expr.name == '==':
# NumExpr == IV, supported by ortools (thanks to `only_numexpr_equality()` transformation)
# NumExpr == IV, supported by Choco (thanks to `only_numexpr_equality()` transformation)
if lhs.name == 'min':
return self.chc_model.min(chcrhs, self.solver_vars(lhs.args))
return self.chc_model.min(chcrhs, self.solver_vars(lhs.args)).post()
elif lhs.name == 'max':
return self.chc_model.max(chcrhs, self.solver_vars(lhs.args))
return self.chc_model.max(chcrhs, self.solver_vars(lhs.args)).post()
elif lhs.name == 'abs':
return self.chc_model.absolute(chcrhs, self.solver_var(lhs.args[0]))
return self.chc_model.absolute(chcrhs, self.solver_var(lhs.args[0])).post()
elif lhs.name == 'mul':
return self.chc_model.times(self.solver_vars(lhs.args[0]), self.solver_vars(lhs.args[1]), chcrhs)
return self.chc_model.times(self.solver_vars(lhs.args[0]), self.solver_vars(lhs.args[1]), chcrhs).post()
elif lhs.name == 'div':
return self.chc_model.div(self.solver_vars(lhs.args[0]), self.solver_vars(lhs.args[1]), chcrhs)
return self.chc_model.div(self.solver_vars(lhs.args[0]), self.solver_vars(lhs.args[1]), chcrhs).post()
elif lhs.name == 'element':
# arr[idx]==rvar (arr=arg0,idx=arg1), ort: (idx,arr,target)
return self.chc_model.element(self.solver_var(lhs.args[1]),
self.solver_vars(lhs.args[0]), chcrhs)
return self.chc_model.element(chcrhs, self.solver_vars(lhs.args[0]), self.solver_var(lhs.args[1])).post()
elif lhs.name == 'mod':
# catch tricky-to-find ortools limitation
divisor = lhs.args[1]
if not is_num(divisor):
if divisor.lb <= 0 and divisor.ub >= 0:
raise Exception(
f"Expression '{lhs}': or-tools does not accept a 'modulo' operation where '0' is in the domain of the divisor {divisor}:domain({divisor.lb}, {divisor.ub}). Even if you add a constraint that it can not be '0'. You MUST use a variable that is defined to be higher or lower than '0'.")
return self.ort_model.AddModuloEquality(ortrhs, *self.solver_vars(lhs.args))
f"Expression '{lhs}': Choco does not accept a 'modulo' operation where '0' is in the domain of the divisor {divisor}:domain({divisor.lb}, {divisor.ub}). Even if you add a constraint that it can not be '0'. You MUST use a variable that is defined to be higher or lower than '0'.")
return self.chc_model.mod(self.solver_vars(lhs.args[0]), self.solver_vars(lhs.args)[1], chcrhs).post()
elif lhs.name == 'pow':
# only `POW(b,2) == IV` supported, post as b*b == IV
assert (lhs.args[1] == 2), "Ort: 'pow', only var**2 supported, no other exponents"
b = self.solver_var(lhs.args[0])
return self.ort_model.AddMultiplicationEquality(ortrhs, [b,b])
#assert (lhs.args[1] == 2), "Ort: 'pow', only var**2 supported, no other exponents"
chclhs = self.solver_var(lhs.args[0])
return self.chc_model.pow(self.solver_vars(lhs.args[0]), self.solver_vars(lhs.args)[1], chcrhs).post()
#return self.chc_model.arithm(chcrhs, "=", chclhs.pow()).post()
raise NotImplementedError(
"Not a known supported ORTools left-hand-side '{}' {}".format(lhs.name, cpm_expr))
"Not a known supported Choco left-hand-side '{}' {}".format(lhs.name, cpm_expr))

# base (Boolean) global constraints
elif isinstance(cpm_expr, GlobalConstraint):
Expand Down

0 comments on commit 6f3ccbf

Please sign in to comment.