Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
IgnaceBleukx committed Feb 26, 2024
1 parent 9215a25 commit 7c7d8d6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
6 changes: 3 additions & 3 deletions cpmpy/tools/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def write_cnf(model, fname=None):
vars = get_variables(constraints)
mapping = {v : i+1 for i, v in enumerate(vars)}

out = f"p {len(vars)} {len(constraints)}\n"
out = f"p cnf {len(vars)} {len(constraints)}\n"
for cons in constraints:

if isinstance(cons, _BoolVarImpl):
Expand Down Expand Up @@ -103,9 +103,9 @@ def read_cnf(fname, sep=None):
clause = []
for i, var_idx in enumerate(map(int, str_idxes)):
if var_idx > 0: # boolvar
clause.append(bvs[i-1])
clause.append(bvs[var_idx-1])
elif var_idx < 0: # neg boolvar
clause.append(bvs[(-i)-1])
clause.append(~bvs[(-var_idx)-1])
elif var_idx == 0: # end of clause
assert i == len(str_idxes)-1, f"Can only have '0' at end of a clause, but got 0 at index {i} in clause {str_idxes}"
m += cp.any(clause)
Expand Down
44 changes: 44 additions & 0 deletions tests/test_tool_cnf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import unittest
import tempfile

import cpmpy as cp
from cpmpy.tools.cnf import read_cnf, write_cnf
from cpmpy.transformations.get_variables import get_variables_model
class CNFTool(unittest.TestCase):

def test_read_cnf(self):

tmp_file = tempfile.NamedTemporaryFile()

"""
a | b | c,
~b | ~c,
~a
"""
cnf_txt = "p cnf 3 3\n1 2 3 0\n-2 -3 0\n-1 0\n"
with open(tmp_file.name, "w") as f:
f.write(cnf_txt)

model = read_cnf(tmp_file.name)

vars = sorted(get_variables_model(model), key=str)

sols = set()
addsol = lambda : sols.add(tuple([v.value() for v in vars]))

self.assertEqual(model.solveAll(display=addsol), 2)
self.assertSetEqual(sols, {(False, False, True), (False, True, False)})

def test_write_cnf(self):

a,b,c = [cp.boolvar(name=n) for n in "abc"]

m = cp.Model()
m += cp.any([a,b,c])
m += b.implies(~c)
m += a <= 0

cnf_txt = write_cnf(m)
gt_cnf = "p cnf 3 3\n1 2 3 0\n-2 -3 0\n-1 0\n"

self.assertEqual(cnf_txt, gt_cnf)

0 comments on commit 7c7d8d6

Please sign in to comment.