From b2e7388fa7e687ffb237b0cbf72e02a2ff044182 Mon Sep 17 00:00:00 2001 From: Wout Date: Tue, 17 Oct 2023 13:16:17 +0200 Subject: [PATCH] xor decomposition using sum * recursive xor using sum. --- cpmpy/expressions/globalconstraints.py | 15 +++++---------- tests/test_globalconstraints.py | 2 ++ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/cpmpy/expressions/globalconstraints.py b/cpmpy/expressions/globalconstraints.py index 4dd300e33..db4cc3ff6 100644 --- a/cpmpy/expressions/globalconstraints.py +++ b/cpmpy/expressions/globalconstraints.py @@ -394,16 +394,11 @@ def __init__(self, arg_list): super().__init__("xor", flatargs) def decompose(self): - # there are multiple decompositions possible - # sum(args) mod 2 == 1, for size 2: sum(args) == 1 - # since Xor is logical constraint, the default is a logic decomposition - a0, a1 = self.args[:2] - cons = (a0 | a1) & (~a0 | ~a1) # one true and one false - - # for more than 2 variables, we cascade (decomposed) xors - for arg in self.args[2:]: - cons = (cons | arg) & (~cons | ~arg) - return [cons], [] + # there are multiple decompositions possible, Recursively using sum allows it to be efficient for all solvers. + decomp = [sum(self.args[:2]) == 1] + if len(self.args) > 2: + decomp = Xor([decomp,self.args[2:]]).decompose()[0] + return decomp, [] def value(self): return sum(argval(a) for a in self.args) % 2 == 1 diff --git a/tests/test_globalconstraints.py b/tests/test_globalconstraints.py index a4c3db386..fc237f575 100644 --- a/tests/test_globalconstraints.py +++ b/tests/test_globalconstraints.py @@ -318,6 +318,8 @@ def test_not_xor(self): self.assertFalse(cp.Xor(bv).value()) nbNotModels = cp.Model(~cp.Xor(bv)).solveAll(display=lambda: self.assertFalse(cp.Xor(bv).value())) nbModels = cp.Model(cp.Xor(bv)).solveAll(display=lambda: self.assertTrue(cp.Xor(bv).value())) + nbDecompModels = cp.Model(cp.Xor(bv).decompose()).solveAll(display=lambda: self.assertTrue(cp.Xor(bv).value())) + self.assertEqual(nbDecompModels,nbModels) total = cp.Model(bv == bv).solveAll() self.assertEqual(str(total), str(nbModels + nbNotModels))