diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index 5e4590010df..ad4ea1f9c24 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -46,11 +46,14 @@ """ import collections +import numbers import threading import time -from typing import Optional +from typing import Optional, Iterable, Union, Tuple, Sequence, overload import warnings +import numpy as np + from ortools.sat import cp_model_pb2 from ortools.sat import sat_parameters_pb2 from ortools.sat.python import cp_model_helper as cmh @@ -94,6 +97,13 @@ PORTFOLIO_SEARCH = sat_parameters_pb2.SatParameters.PORTFOLIO_SEARCH LP_SEARCH = sat_parameters_pb2.SatParameters.LP_SEARCH +# Type aliases +IntegralT = Union[numbers.Integral, np.integer, int] +LiteralT = Union["IntVar", "_NotBooleanVariable", IntegralT, bool] +VariableT = Union["IntVar", "_ProductCst", IntegralT] +LinearExprT = Union["LinearExpr", "IntVar", "_ProductCst", IntegralT] +ArcT = Tuple[IntegralT, IntegralT, LiteralT] + def DisplayBounds(bounds): """Displays a flattened list of intervals.""" @@ -144,7 +154,7 @@ def ShortExprName(model, e): return str(e) -class LinearExpr(object): +class LinearExpr: """Holds an integer linear expression. A linear expression is built from integer constants and variables. @@ -657,15 +667,15 @@ def __init__(self, model, domain, name): self.__var.domain.extend(domain.FlattenedIntervals()) self.__var.name = name - def Index(self): + def Index(self) -> int: """Returns the index of the variable in the model.""" return self.__index - def Proto(self): + def Proto(self) -> cp_model_pb2.IntegerVariableProto: """Returns the variable protobuf.""" return self.__var - def IsEqualTo(self, other): + def IsEqualTo(self, other) -> bool: """Returns true if self == other in the python sense.""" if not isinstance(other, IntVar): return False @@ -686,10 +696,10 @@ def __str__(self): def __repr__(self): return "%s(%s)" % (self.__var.name, DisplayBounds(self.__var.domain)) - def Name(self): + def Name(self) -> str: return self.__var.name - def Not(self): + def Not(self) -> "_NotBooleanVariable": """Returns the negation of a Boolean variable. This method implements the logical negation of a Boolean variable. @@ -709,7 +719,9 @@ def Not(self): class _NotBooleanVariable(LinearExpr): """Negation of a boolean variable.""" - def __init__(self, boolvar): + __boolvar: IntVar + + def __init__(self, boolvar: IntVar): self.__boolvar = boolvar def Index(self): @@ -727,7 +739,7 @@ def __bool__(self): ) -class BoundedLinearExpression(object): +class BoundedLinearExpression: """Represents a linear constraint: `lb <= linear expression <= ub`. The only use of this class is to be added to the CpModel through @@ -799,7 +811,7 @@ def __bool__(self): ) -class Constraint(object): +class Constraint: """Base class for constraints. Constraints are built by the CpModel through the Add methods. @@ -814,10 +826,20 @@ class Constraint(object): model.Add(x + 2 * y == 5).OnlyEnforceIf(b.Not()) """ + __constraint: cp_model_pb2.ConstraintProto + def __init__(self, constraints): self.__index = len(constraints) self.__constraint = constraints.add() + @overload + def OnlyEnforceIf(self, boolvar: Iterable[LiteralT]): + ... + + @overload + def OnlyEnforceIf(self, *boolvar: LiteralT): + ... + def OnlyEnforceIf(self, *boolvar): """Adds an enforcement literal to the constraint. @@ -853,20 +875,20 @@ def WithName(self, name): self.__constraint.ClearField("name") return self - def Name(self): + def Name(self) -> str: """Returns the name of the constraint.""" return self.__constraint.name - def Index(self): + def Index(self) -> int: """Returns the index of the constraint in the model.""" return self.__index - def Proto(self): + def Proto(self) -> cp_model_pb2.ConstraintProto: """Returns the constraint protobuf.""" return self.__constraint -class IntervalVar(object): +class IntervalVar: """Represents an Interval variable. An interval variable is both a constraint and a variable. It is defined by @@ -884,6 +906,8 @@ class IntervalVar(object): intervals into the schedule. """ + __model: cp_model_pb2.CpModelProto + def __init__(self, model, start, size, end, is_present_index, name): self.__model = model # As with the IntVar::__init__ method, we hack the __init__ method to @@ -911,7 +935,7 @@ def Index(self): """Returns the index of the interval constraint in the model.""" return self.__index - def Proto(self): + def Proto(self) -> cp_model_pb2.IntervalConstraintProto: """Returns the interval protobuf.""" return self.__ct.interval @@ -981,7 +1005,7 @@ def ObjectIsAFalseLiteral(literal): return False -class CpModel(object): +class CpModel: """Methods for building a CP model. Methods beginning with: @@ -990,22 +1014,24 @@ class CpModel(object): * ```Add``` create new constraints and add them to the model. """ + __model: cp_model_pb2.CpModelProto + def __init__(self): self.__model = cp_model_pb2.CpModelProto() self.__constant_map = {} # Naming. - def Name(self): + def Name(self) -> str: """Returns the name of the model.""" return self.__model.name - def SetName(self, name): + def SetName(self, name: str): """Sets the name of the model.""" self.__model.name = name # Integer variable. - def NewIntVar(self, lb, ub, name): + def NewIntVar(self, lb: IntegralT, ub: IntegralT, name: str) -> IntVar: """Create an integer variable with domain [lb, ub]. The CP-SAT solver is limited to integer variables. If you have fractional @@ -1023,7 +1049,7 @@ def NewIntVar(self, lb, ub, name): return IntVar(self.__model, Domain(lb, ub), name) - def NewIntVarFromDomain(self, domain, name): + def NewIntVarFromDomain(self, domain, name: str) -> IntVar: """Create an integer variable from a domain. A domain is a set of integers specified by a collection of intervals. @@ -1039,21 +1065,23 @@ def NewIntVarFromDomain(self, domain, name): """ return IntVar(self.__model, domain, name) - def NewBoolVar(self, name): + def NewBoolVar(self, name: str) -> IntVar: """Creates a 0-1 variable with the given name.""" return IntVar(self.__model, Domain(0, 1), name) - def NewConstant(self, value): + def NewConstant(self, value: IntegralT) -> IntVar: """Declares a constant integer.""" return IntVar(self.__model, self.GetOrMakeIndexFromConstant(value), None) # Linear constraints. - def AddLinearConstraint(self, linear_expr, lb, ub): + def AddLinearConstraint( + self, linear_expr, lb: IntegralT, ub: IntegralT + ) -> Constraint: """Adds the constraint: `lb <= linear_expr <= ub`.""" return self.AddLinearExpressionInDomain(linear_expr, Domain(lb, ub)) - def AddLinearExpressionInDomain(self, linear_expr, domain): + def AddLinearExpressionInDomain(self, linear_expr, domain) -> Constraint: """Adds the constraint: `linear_expr` in `domain`.""" if isinstance(linear_expr, LinearExpr): ct = Constraint(self.__model.constraints) @@ -1085,7 +1113,7 @@ def AddLinearExpressionInDomain(self, linear_expr, domain): + ")" ) - def Add(self, ct): + def Add(self, ct: Union[BoundedLinearExpression, bool]) -> Constraint: """Adds a `BoundedLinearExpression` to the model. Args: @@ -1107,6 +1135,14 @@ def Add(self, ct): # General Integer Constraints. + @overload + def AddAllDifferent(self, expressions: Iterable[LinearExprT]) -> Constraint: + ... + + @overload + def AddAllDifferent(self, *expressions: LinearExprT) -> Constraint: + ... + def AddAllDifferent(self, *expressions): """Adds AllDifferent(expressions). @@ -1126,7 +1162,9 @@ def AddAllDifferent(self, *expressions): ) return ct - def AddElement(self, index, variables, target): + def AddElement( + self, index: VariableT, variables: Iterable[VariableT], target: VariableT + ) -> Constraint: """Adds the element constraint: `variables[index] == target`. Args: @@ -1151,7 +1189,7 @@ def AddElement(self, index, variables, target): model_ct.element.target = self.GetOrMakeIndex(target) return ct - def AddCircuit(self, arcs): + def AddCircuit(self, arcs: Iterable[ArcT]) -> Constraint: """Adds Circuit(arcs). Adds a circuit constraint from a sparse list of arcs that encode the graph. @@ -1186,7 +1224,7 @@ def AddCircuit(self, arcs): model_ct.circuit.literals.append(lit) return ct - def AddMultipleCircuit(self, arcs): + def AddMultipleCircuit(self, arcs: Iterable[ArcT]) -> Constraint: """Adds a multiple circuit constraint, aka the "VRP" constraint. The direct graph where arc #i (from tails[i] to head[i]) is present iff @@ -1223,7 +1261,9 @@ def AddMultipleCircuit(self, arcs): model_ct.routes.literals.append(lit) return ct - def AddAllowedAssignments(self, variables, tuples_list): + def AddAllowedAssignments( + self, variables: Sequence[VariableT], tuples_list: Iterable[Sequence[IntegralT]] + ) -> Constraint: """Adds AllowedAssignments(variables, tuples_list). An AllowedAssignments constraint is a constraint on an array of variables, @@ -1263,7 +1303,9 @@ def AddAllowedAssignments(self, variables, tuples_list): model_ct.table.values.extend(ar) return ct - def AddForbiddenAssignments(self, variables, tuples_list): + def AddForbiddenAssignments( + self, variables: Sequence[VariableT], tuples_list: Iterable[Sequence[IntegralT]] + ) -> Constraint: """Adds AddForbiddenAssignments(variables, [tuples_list]). A ForbiddenAssignments constraint is a constraint on an array of variables @@ -1295,8 +1337,12 @@ def AddForbiddenAssignments(self, variables, tuples_list): return ct def AddAutomaton( - self, transition_variables, starting_state, final_states, transition_triples - ): + self, + transition_variables: Iterable[VariableT], + starting_state: IntegralT, + final_states: Iterable[IntegralT], + transition_triples: Iterable[Tuple[IntegralT, IntegralT, IntegralT]], + ) -> Constraint: """Adds an automaton constraint. An automaton constraint takes a list of variables (of size *n*), an initial @@ -1368,7 +1414,9 @@ def AddAutomaton( model_ct.automaton.transition_head.append(head) return ct - def AddInverse(self, variables, inverse_variables): + def AddInverse( + self, variables: Sequence[VariableT], inverse_variables: Sequence[VariableT] + ) -> Constraint: """Adds Inverse(variables, inverse_variables). An inverse constraint enforces that if `variables[i]` is assigned a value @@ -1401,7 +1449,13 @@ def AddInverse(self, variables, inverse_variables): ) return ct - def AddReservoirConstraint(self, times, level_changes, min_level, max_level): + def AddReservoirConstraint( + self, + times: Iterable[LinearExprT], + level_changes: Iterable[LinearExprT], + min_level: int, + max_level: int, + ) -> Constraint: """Adds Reservoir(times, level_changes, min_level, max_level). Maintains a reservoir level within bounds. The water level starts at 0, and @@ -1459,8 +1513,13 @@ def AddReservoirConstraint(self, times, level_changes, min_level, max_level): return ct def AddReservoirConstraintWithActive( - self, times, level_changes, actives, min_level, max_level - ): + self, + times: Iterable[LinearExprT], + level_changes: Iterable[LinearExprT], + actives: Iterable[LiteralT], + min_level: int, + max_level: int, + ) -> Constraint: """Adds Reservoir(times, level_changes, actives, min_level, max_level). Maintains a reservoir level within bounds. The water level starts at 0, and @@ -1523,13 +1582,15 @@ def AddReservoirConstraintWithActive( [self.ParseLinearExpression(x) for x in level_changes] ) model_ct.reservoir.active_literals.extend( - [self.GetOrMakeIndex(x) for x in actives] + [self.GetOrMakeBooleanIndex(x) for x in actives] ) model_ct.reservoir.min_level = min_level model_ct.reservoir.max_level = max_level return ct - def AddMapDomain(self, var, bool_var_array, offset=0): + def AddMapDomain( + self, var: IntVar, bool_var_array: Iterable[IntVar], offset: IntegralT = 0 + ): """Adds `var == i + offset <=> bool_var_array[i] == true for all i`.""" for i, bool_var in enumerate(bool_var_array): @@ -1550,7 +1611,7 @@ def AddMapDomain(self, var, bool_var_array, offset=0): if offset + i + 1 <= INT_MAX: model_ct.linear.domain.extend([offset + i + 1, INT_MAX]) - def AddImplication(self, a, b): + def AddImplication(self, a: LiteralT, b: LiteralT) -> Constraint: """Adds `a => b` (`a` implies `b`).""" ct = Constraint(self.__model.constraints) model_ct = self.__model.constraints[ct.Index()] @@ -1558,6 +1619,14 @@ def AddImplication(self, a, b): model_ct.enforcement_literal.append(self.GetOrMakeBooleanIndex(a)) return ct + @overload + def AddBoolOr(self, literals: Iterable[LiteralT]) -> Constraint: + ... + + @overload + def AddBoolOr(self, *literals: LiteralT) -> Constraint: + ... + def AddBoolOr(self, *literals): """Adds `Or(literals) == true`: Sum(literals) >= 1.""" ct = Constraint(self.__model.constraints) @@ -1567,10 +1636,26 @@ def AddBoolOr(self, *literals): ) return ct + @overload + def AddAtLeastOne(self, literals: Iterable[LiteralT]) -> Constraint: + ... + + @overload + def AddAtLeastOne(self, *literals: LiteralT) -> Constraint: + ... + def AddAtLeastOne(self, *literals): """Same as `AddBoolOr`: `Sum(literals) >= 1`.""" return self.AddBoolOr(*literals) + @overload + def AddAtMostOne(self, literals: Iterable[LiteralT]) -> Constraint: + ... + + @overload + def AddAtMostOne(self, *literals: LiteralT) -> Constraint: + ... + def AddAtMostOne(self, *literals): """Adds `AtMostOne(literals)`: `Sum(literals) <= 1`.""" ct = Constraint(self.__model.constraints) @@ -1580,6 +1665,14 @@ def AddAtMostOne(self, *literals): ) return ct + @overload + def AddExactlyOne(self, literals: Iterable[LiteralT]) -> Constraint: + ... + + @overload + def AddExactlyOne(self, *literals: LiteralT) -> Constraint: + ... + def AddExactlyOne(self, *literals): """Adds `ExactlyOne(literals)`: `Sum(literals) == 1`.""" ct = Constraint(self.__model.constraints) @@ -1589,6 +1682,14 @@ def AddExactlyOne(self, *literals): ) return ct + @overload + def AddBoolAnd(self, literals: Iterable[LiteralT]) -> Constraint: + ... + + @overload + def AddBoolAnd(self, *literals: LiteralT) -> Constraint: + ... + def AddBoolAnd(self, *literals): """Adds `And(literals) == true`.""" ct = Constraint(self.__model.constraints) @@ -1598,6 +1699,14 @@ def AddBoolAnd(self, *literals): ) return ct + @overload + def AddBoolXOr(self, literals: Iterable[LiteralT]) -> Constraint: + ... + + @overload + def AddBoolXOr(self, *literals: LiteralT) -> Constraint: + ... + def AddBoolXOr(self, *literals): """Adds `XOr(literals) == true`. @@ -1617,7 +1726,9 @@ def AddBoolXOr(self, *literals): ) return ct - def AddMinEquality(self, target, exprs): + def AddMinEquality( + self, target: LinearExprT, exprs: Iterable[LinearExprT] + ) -> Constraint: """Adds `target == Min(exprs)`.""" ct = Constraint(self.__model.constraints) model_ct = self.__model.constraints[ct.Index()] @@ -1627,7 +1738,9 @@ def AddMinEquality(self, target, exprs): model_ct.lin_max.target.CopyFrom(self.ParseLinearExpression(target, True)) return ct - def AddMaxEquality(self, target, exprs): + def AddMaxEquality( + self, target: LinearExprT, exprs: Iterable[LinearExprT] + ) -> Constraint: """Adds `target == Max(exprs)`.""" ct = Constraint(self.__model.constraints) model_ct = self.__model.constraints[ct.Index()] @@ -1635,7 +1748,9 @@ def AddMaxEquality(self, target, exprs): model_ct.lin_max.target.CopyFrom(self.ParseLinearExpression(target)) return ct - def AddDivisionEquality(self, target, num, denom): + def AddDivisionEquality( + self, target: LinearExprT, num: LinearExprT, denom: LinearExprT + ) -> Constraint: """Adds `target == num // denom` (integer division rounded towards 0).""" ct = Constraint(self.__model.constraints) model_ct = self.__model.constraints[ct.Index()] @@ -1644,7 +1759,7 @@ def AddDivisionEquality(self, target, num, denom): model_ct.int_div.target.CopyFrom(self.ParseLinearExpression(target)) return ct - def AddAbsEquality(self, target, expr): + def AddAbsEquality(self, target: LinearExprT, expr: LinearExprT) -> Constraint: """Adds `target == Abs(var)`.""" ct = Constraint(self.__model.constraints) model_ct = self.__model.constraints[ct.Index()] @@ -1653,7 +1768,9 @@ def AddAbsEquality(self, target, expr): model_ct.lin_max.target.CopyFrom(self.ParseLinearExpression(target)) return ct - def AddModuloEquality(self, target, var, mod): + def AddModuloEquality( + self, target: LinearExprT, var: LinearExprT, mod: LinearExprT + ) -> Constraint: """Adds `target = var % mod`.""" ct = Constraint(self.__model.constraints) model_ct = self.__model.constraints[ct.Index()] @@ -1662,7 +1779,11 @@ def AddModuloEquality(self, target, var, mod): model_ct.int_mod.target.CopyFrom(self.ParseLinearExpression(target)) return ct - def AddMultiplicationEquality(self, target, *expressions): + def AddMultiplicationEquality( + self, + target: LinearExprT, + *expressions: Union[Iterable[LinearExprT], LinearExprT], + ) -> Constraint: """Adds `target == expressions[0] * .. * expressions[n]`.""" ct = Constraint(self.__model.constraints) model_ct = self.__model.constraints[ct.Index()] @@ -1677,7 +1798,9 @@ def AddMultiplicationEquality(self, target, *expressions): # Scheduling support - def NewIntervalVar(self, start, size, end, name): + def NewIntervalVar( + self, start: LinearExprT, size: LinearExprT, end: LinearExprT, name: str + ) -> IntervalVar: """Creates an interval variable from start, size, and end. An interval variable is a constraint, that is itself used in other @@ -1712,7 +1835,9 @@ def NewIntervalVar(self, start, size, end, name): raise TypeError("cp_model.NewIntervalVar: end must be affine or constant.") return IntervalVar(self.__model, start_expr, size_expr, end_expr, None, name) - def NewFixedSizeIntervalVar(self, start, size, name): + def NewFixedSizeIntervalVar( + self, start: LinearExprT, size: IntegralT, name: str + ) -> IntervalVar: """Creates an interval variable from start, and a fixed size. An interval variable is a constraint, that is itself used in other @@ -1737,7 +1862,14 @@ def NewFixedSizeIntervalVar(self, start, size, name): ) return IntervalVar(self.__model, start_expr, size_expr, end_expr, None, name) - def NewOptionalIntervalVar(self, start, size, end, is_present, name): + def NewOptionalIntervalVar( + self, + start: LinearExprT, + size: LinearExprT, + end: LinearExprT, + is_present: LiteralT, + name: str, + ) -> IntervalVar: """Creates an optional interval var from start, size, end, and is_present. An optional interval variable is a constraint, that is itself used in other @@ -1781,7 +1913,9 @@ def NewOptionalIntervalVar(self, start, size, end, is_present, name): self.__model, start_expr, size_expr, end_expr, is_present_index, name ) - def NewOptionalFixedSizeIntervalVar(self, start, size, is_present, name): + def NewOptionalFixedSizeIntervalVar( + self, start: LinearExprT, size: IntegralT, is_present: LiteralT, name: str + ) -> IntervalVar: """Creates an interval variable from start, and a fixed size. An interval variable is a constraint, that is itself used in other @@ -1811,7 +1945,7 @@ def NewOptionalFixedSizeIntervalVar(self, start, size, is_present, name): self.__model, start_expr, size_expr, end_expr, is_present_index, name ) - def AddNoOverlap(self, interval_vars): + def AddNoOverlap(self, interval_vars: Iterable[IntervalVar]) -> Constraint: """Adds NoOverlap(interval_vars). A NoOverlap constraint ensures that all present intervals do not overlap @@ -1830,7 +1964,9 @@ def AddNoOverlap(self, interval_vars): ) return ct - def AddNoOverlap2D(self, x_intervals, y_intervals): + def AddNoOverlap2D( + self, x_intervals: Iterable[IntervalVar], y_intervals: Iterable[IntervalVar] + ) -> Constraint: """Adds NoOverlap2D(x_intervals, y_intervals). A NoOverlap2D constraint ensures that all present rectangles do not overlap @@ -1857,7 +1993,12 @@ def AddNoOverlap2D(self, x_intervals, y_intervals): ) return ct - def AddCumulative(self, intervals, demands, capacity): + def AddCumulative( + self, + intervals: Iterable[IntervalVar], + demands: Iterable[VariableT], + capacity: Iterable[VariableT], + ) -> Constraint: """Adds Cumulative(intervals, demands, capacity). This constraint enforces that: @@ -1888,7 +2029,7 @@ def AddCumulative(self, intervals, demands, capacity): return ct # Support for deep copy. - def CopyFrom(self, other_model): + def CopyFrom(self, other_model: "CpModel"): """Reset the model, and creates a new one from a CpModelProto instance.""" self.__model.CopyFrom(other_model.Proto()) @@ -1898,7 +2039,7 @@ def CopyFrom(self, other_model): if len(var.domain) == 2 and var.domain[0] == var.domain[1]: self.__constant_map[var.domain[0]] = i - def GetBoolVarFromProtoIndex(self, index): + def GetBoolVarFromProtoIndex(self, index: int) -> IntVar: """Returns an already created Boolean variable from its index.""" if index < 0 or index >= len(self.__model.variables): raise ValueError(f"GetBoolVarFromProtoIndex: out of bound index {index}") @@ -1911,13 +2052,13 @@ def GetBoolVarFromProtoIndex(self, index): return IntVar(self.__model, index, None) - def GetIntVarFromProtoIndex(self, index): + def GetIntVarFromProtoIndex(self, index: int) -> IntVar: """Returns an already created integer variable from its index.""" if index < 0 or index >= len(self.__model.variables): raise ValueError(f"GetIntVarFromProtoIndex: out of bound index {index}") return IntVar(self.__model, index, None) - def GetIntervalVarFromProtoIndex(self, index): + def GetIntervalVarFromProtoIndex(self, index: int) -> IntervalVar: """Returns an already created interval variable from its index.""" if index < 0 or index >= len(self.__model.constraints): raise ValueError( @@ -1937,14 +2078,14 @@ def GetIntervalVarFromProtoIndex(self, index): def __str__(self): return str(self.__model) - def Proto(self): + def Proto(self) -> cp_model_pb2.CpModelProto: """Returns the underlying CpModelProto.""" return self.__model - def Negated(self, index): + def Negated(self, index) -> int: return -index - 1 - def GetOrMakeIndex(self, arg): + def GetOrMakeIndex(self, arg: VariableT): """Returns the index of a variable, its negation, or a number.""" if isinstance(arg, IntVar): return arg.Index() @@ -1960,7 +2101,7 @@ def GetOrMakeIndex(self, arg): else: raise TypeError("NotSupported: model.GetOrMakeIndex(" + str(arg) + ")") - def GetOrMakeBooleanIndex(self, arg): + def GetOrMakeBooleanIndex(self, arg: LiteralT): """Returns an index from a boolean expression.""" if isinstance(arg, IntVar): self.AssertIsBooleanVariable(arg) @@ -1976,12 +2117,12 @@ def GetOrMakeBooleanIndex(self, arg): "NotSupported: model.GetOrMakeBooleanIndex(" + str(arg) + ")" ) - def GetIntervalIndex(self, arg): + def GetIntervalIndex(self, arg: IntervalVar): if not isinstance(arg, IntervalVar): raise TypeError("NotSupported: model.GetIntervalIndex(%s)" % arg) return arg.Index() - def GetOrMakeIndexFromConstant(self, value): + def GetOrMakeIndexFromConstant(self, value: IntegralT): if value in self.__constant_map: return self.__constant_map[value] index = len(self.__model.variables) @@ -1996,7 +2137,9 @@ def VarIndexToVarProto(self, var_index): else: return self.__model.variables[-var_index - 1] - def ParseLinearExpression(self, linear_expr, negate=False): + def ParseLinearExpression( + self, linear_expr, negate=False + ) -> cp_model_pb2.LinearExpressionProto: """Returns a LinearExpressionProto built from a LinearExpr instance.""" result = cp_model_pb2.LinearExpressionProto() mult = -1 if negate else 1 @@ -2019,7 +2162,7 @@ def ParseLinearExpression(self, linear_expr, negate=False): result.coeffs.append(c * mult) return result - def _SetObjective(self, obj, minimize): + def _SetObjective(self, obj: Union[IntVar, LinearExpr, IntegralT], minimize: bool): """Sets the objective of the model.""" self.ClearObjective() if isinstance(obj, IntVar): @@ -2064,11 +2207,11 @@ def _SetObjective(self, obj, minimize): else: raise TypeError("TypeError: " + str(obj) + " is not a valid objective") - def Minimize(self, obj): + def Minimize(self, obj: Union[IntVar, LinearExpr, IntegralT]): """Sets the objective of the model to minimize(obj).""" self._SetObjective(obj, minimize=True) - def Maximize(self, obj): + def Maximize(self, obj: Union[IntVar, LinearExpr, IntegralT]): """Sets the objective of the model to maximize(obj).""" self._SetObjective(obj, minimize=False) @@ -2128,7 +2271,7 @@ def AssertIsBooleanVariable(self, x): elif not isinstance(x, _NotBooleanVariable): raise TypeError("TypeError: " + str(x) + " is not a boolean variable") - def AddHint(self, var, value): + def AddHint(self, var, value: int): """Adds 'var == value' as a hint to the solver.""" self.__model.solution_hint.vars.append(self.GetOrMakeIndex(var)) self.__model.solution_hint.values.append(value) @@ -2211,7 +2354,7 @@ def EvaluateBooleanExpression(literal, solution): raise TypeError(f"Cannot interpret {literal} as a boolean expression.") -class CpSolver(object): +class CpSolver: """Main solver class. The purpose of this class is to search for a solution to the model provided @@ -2222,14 +2365,21 @@ class CpSolver(object): about the solve procedure. """ + __solution: Optional[cp_model_pb2.CpSolverResponse] + parameters: sat_parameters_pb2.SatParameters + def __init__(self): - self.__solution: Optional[cp_model_pb2.CpSolverResponse] = None + self.__solution = None self.parameters = sat_parameters_pb2.SatParameters() self.log_callback = None self.__solve_wrapper: swig_helper.SolveWrapper = None self.__lock = threading.Lock() - def Solve(self, model, solution_callback=None): + def Solve( + self, + model: CpModel, + solution_callback: Optional["CpSolverSolutionCallback"] = None, + ): """Solves a problem and passes each solution to the callback if not null.""" with self.__lock: self.__solve_wrapper = swig_helper.SolveWrapper()