Skip to content

Commit

Permalink
remove style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Wout4 committed Sep 19, 2023
1 parent 63e3f0a commit 52cd76d
Showing 1 changed file with 29 additions and 34 deletions.
63 changes: 29 additions & 34 deletions cpmpy/expressions/variables.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
#-*- coding:utf-8 -*-
##
## variables.py
##
Expand Down Expand Up @@ -49,7 +49,7 @@
import math
import uuid
from collections.abc import Iterable
import warnings # for deprecation warning
import warnings # for deprecation warning
from functools import reduce

import numpy as np
Expand Down Expand Up @@ -115,7 +115,7 @@ def boolvar(shape=1, name=None):
return _BoolVarImpl(name=name)

# create base data
data = np.array([_BoolVarImpl(name=_genname(name, idxs)) for idxs in np.ndindex(shape)]) # repeat new instances
data = np.array([_BoolVarImpl(name=_genname(name, idxs)) for idxs in np.ndindex(shape)]) # repeat new instances
# insert into custom ndarray
return NDVarArray(shape, dtype=object, buffer=data)

Expand Down Expand Up @@ -176,8 +176,7 @@ def intvar(lb, ub, shape=1, name=None):
return _IntVarImpl(lb, ub, name=name)

# create base data
data = np.array(
[_IntVarImpl(lb, ub, name=_genname(name, idxs)) for idxs in np.ndindex(shape)]) # repeat new instances
data = np.array([_IntVarImpl(lb, ub, name=_genname(name, idxs)) for idxs in np.ndindex(shape)]) # repeat new instances
# insert into custom ndarray
return NDVarArray(shape, dtype=object, buffer=data)

Expand Down Expand Up @@ -219,7 +218,6 @@ class NullShapeError(Exception):
"""
Error returned when providing an empty or size 0 shape for numpy arrays of variables
"""

def __init__(self, shape, message="Shape should be non-zero"):
self.shape = shape
self.message = message
Expand All @@ -235,14 +233,12 @@ class _NumVarImpl(Expression):
Abstract class, only mean to be subclassed
"""

def __init__(self, lb, ub, name):
assert (is_num(lb) and is_num(ub))
assert (lb <= ub)
self.lb = lb
self.ub = ub
self.name = name
self.id = uuid.uuid4()
self._value = None

def is_bool(self):
Expand All @@ -268,7 +264,7 @@ def clear(self):
def __repr__(self):
return self.name

# for sets/dicts. Because id's are unique, so is the str repr
# for sets/dicts. Use the unique ID
def __hash__(self):
# for backwards compatability
if not hasattr(self, 'id'):
Expand All @@ -290,9 +286,9 @@ def __init__(self, lb, ub, name=None):

if name is None:
name = "IV{}".format(_IntVarImpl.counter)
_IntVarImpl.counter = _IntVarImpl.counter + 1 # static counter
_IntVarImpl.counter = _IntVarImpl.counter + 1 # static counter

super().__init__(int(lb), int(ub), name=name) # explicit cast: can be numpy
super().__init__(int(lb), int(ub), name=name) # explicit cast: can be numpy

# special casing for intvars (and boolvars)
def __abs__(self):
Expand All @@ -311,12 +307,12 @@ class _BoolVarImpl(_IntVarImpl):
counter = 0

def __init__(self, lb=0, ub=1, name=None):
assert (lb == 0 or lb == 1)
assert (ub == 0 or ub == 1)
assert(lb == 0 or lb == 1)
assert(ub == 0 or ub == 1)

if name is None:
name = "BV{}".format(_BoolVarImpl.counter)
_BoolVarImpl.counter = _BoolVarImpl.counter + 1 # static counter
_BoolVarImpl.counter = _BoolVarImpl.counter + 1 # static counter
_IntVarImpl.__init__(self, lb, ub, name=name)

def is_bool(self):
Expand Down Expand Up @@ -344,9 +340,8 @@ class NegBoolView(_BoolVarImpl):
Do not create this object directly, use the `~` operator instead: `~bv`
"""

def __init__(self, bv):
# assert(isinstance(bv, _BoolVarImpl))
#assert(isinstance(bv, _BoolVarImpl))
self._bv = bv
# as it is always created using the ~ operator (only available for _BoolVarImpl)
# it already comply with the asserts of the __init__ of _BoolVarImpl and can use
Expand Down Expand Up @@ -382,7 +377,6 @@ class NDVarArray(np.ndarray, Expression):
Do not create this object directly, use one of the functions in this module
"""

def __init__(self, shape, **kwargs):
# TODO: global name?
# this is nice and sneaky, 'self' is the list_of_arguments!
Expand Down Expand Up @@ -420,7 +414,7 @@ def __repr__(self):
return super().__repr__()

def __getitem__(self, index):
from .globalfunctions import Element # here to avoid circular
from .globalfunctions import Element # here to avoid circular
# array access, check if variables are used in the indexing

# index is single expression: direct element
Expand All @@ -430,14 +424,14 @@ def __getitem__(self, index):
# multi-dimensional index
if isinstance(index, tuple) and any(isinstance(el, Expression) for el in index):
# find dimension of expression in index
expr_dim = next(dim for dim, idx in enumerate(index) if isinstance(idx, Expression))
arr = self[tuple(index[:expr_dim])] # select remaining dimensions
expr_dim = next(dim for dim,idx in enumerate(index) if isinstance(idx, Expression))
arr = self[tuple(index[:expr_dim])] # select remaining dimensions
index = index[expr_dim:]

# calculate index for flat array
flat_index = index[-1]
for dim, idx in enumerate(index[:-1]):
flat_index += idx * math.prod(arr.shape[dim + 1:])
flat_index += idx * math.prod(arr.shape[dim+1:])
# using index expression as single var for flat array
return Element(arr.flatten(), flat_index)

Expand All @@ -453,13 +447,12 @@ def __getitem__(self, index):
"""
make the given array the first dimension in the returned array
"""

def __axis(self, axis):

arr = self

# correct type and value checks
if not isinstance(axis, int):
if not isinstance(axis,int):
raise TypeError("Axis keyword argument in .sum() should always be an integer")
if axis >= arr.ndim:
raise ValueError("Axis out of range")
Expand All @@ -483,7 +476,7 @@ def sum(self, axis=None, out=None):
if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the sum over the whole array
if axis is None: # simple case where we want the sum over the whole array
arr = self.flatten()
return Operator("sum", arr)

Expand All @@ -496,14 +489,15 @@ def sum(self, axis=None, out=None):
# return the NDVarArray that contains the sum constraints
return out


def prod(self, axis=None, out=None):
"""
overwrite np.prod(NDVarArray) as people might use it
"""
if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the product over the whole array
if axis is None: # simple case where we want the product over the whole array
arr = self.flatten()
return reduce(lambda a, b: a * b, arr)

Expand All @@ -524,7 +518,7 @@ def max(self, axis=None, out=None):
if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the maximum over the whole array
if axis is None: # simple case where we want the maximum over the whole array
arr = self.flatten()
return Maximum(arr)

Expand All @@ -545,7 +539,7 @@ def min(self, axis=None, out=None):
if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the Minimum over the whole array
if axis is None: # simple case where we want the Minimum over the whole array
arr = self.flatten()
return Minimum(arr)

Expand All @@ -570,7 +564,7 @@ def any(self, axis=None, out=None):
if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the .any() over the whole array
if axis is None: # simple case where we want the .any() over the whole array
arr = self.flatten()
return any(arr)

Expand All @@ -592,7 +586,7 @@ def all(self, axis=None, out=None):
if out is not None:
raise NotImplementedError()

if axis is None: # simple case where we want the .all() over the whole array
if axis is None: # simple case where we want the .all() over the whole array
arr = self.flatten()
return all(arr)

Expand All @@ -608,10 +602,10 @@ def all(self, axis=None, out=None):
# VECTORIZED master function (delegate)
def _vectorized(self, other, attr):
if not isinstance(other, Iterable):
other = [other] * len(self)
other = [other]*len(self)
# this is a bit cryptic, but it calls 'attr' on s with o as arg
# s.__eq__(o) <-> getattr(s, '__eq__')(o)
return cpm_array([getattr(s, attr)(o) for s, o in zip(self, other)])
return cpm_array([getattr(s,attr)(o) for s,o in zip(self, other)])

# VECTORIZED comparisons
def __eq__(self, other):
Expand Down Expand Up @@ -710,12 +704,12 @@ def __rxor__(self, other):
def implies(self, other):
return self._vectorized(other, 'implies')

# in __contains__(self, value) Check membership
#in __contains__(self, value) Check membership
# CANNOT meaningfully overwrite, python always returns True/False
# regardless of what you return in the __contains__ function

# TODO?
# object.__matmul__(self, other)
#object.__matmul__(self, other)


def _genname(basename, idxs):
Expand All @@ -731,4 +725,5 @@ def _genname(basename, idxs):
if basename == None:
return None
stridxs = ",".join(map(str, idxs))
return f"{basename}[{stridxs}]" # "<name>[<idx0>,<idx1>,...]"
return f"{basename}[{stridxs}]" # "<name>[<idx0>,<idx1>,...]"

0 comments on commit 52cd76d

Please sign in to comment.