Skip to content

Commit

Permalink
Allowing adding many clauses
Browse files Browse the repository at this point in the history
  • Loading branch information
msoos committed Oct 15, 2023
1 parent c4b782a commit 9a7e9d7
Showing 1 changed file with 133 additions and 0 deletions.
133 changes: 133 additions & 0 deletions python/src/pyapproxmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,138 @@ static PyObject* add_clause(Counter *self, PyObject *args, PyObject *kwds)

}

template <typename T>
static int _add_clauses_from_array(Counter *self, const size_t array_length, const T *array)
{
if (array_length == 0) {
return 1;
}
if (array[array_length - 1] != 0) {
PyErr_SetString(PyExc_ValueError, "last clause not terminated by zero");
return 0;
}
size_t k = 0;
long val = 0;
std::vector<CMSat::Lit>& lits = self->tmp_cl_lits;
for (val = (long) array[k]; k < array_length; val = (long) array[++k]) {
lits.clear();
long int max_var = 0;
for (; k < array_length && val != 0; val = (long) array[++k]) {
long var;
bool sign;
if (val > std::numeric_limits<int>::max()/2
|| val < std::numeric_limits<int>::min()/2
) {
PyErr_Format(PyExc_ValueError, "integer %ld is too small or too large", val);
return 0;
}

sign = (val < 0);
var = std::abs(val) - 1;
max_var = std::max(var, max_var);

lits.push_back(CMSat::Lit(var, sign));
}
if (!lits.empty()) {
if (max_var >= (long int)self->appmc->nVars()) {
self->appmc->new_vars(max_var-(long int)self->appmc->nVars()+1);
}
self->appmc->add_clause(lits);
}
}
return 1;
}

static int _add_clauses_from_buffer(Counter *self, Py_buffer *view)
{
if (view->ndim != 1) {
PyErr_Format(PyExc_ValueError, "invalid clause array: expected 1-D array, got %d-D", view->ndim);
return 0;
}
if (strcmp(view->format, "i") != 0 && strcmp(view->format, "l") != 0 && strcmp(view->format, "q") != 0) {
PyErr_Format(PyExc_ValueError, "invalid clause array: invalid format '%s'", view->format);
return 0;
}

void * array_address = view->buf;
size_t itemsize = view->itemsize;
size_t array_length = view->len / itemsize;

if (itemsize == sizeof(int)) {
return _add_clauses_from_array(self, array_length, (const int *) array_address);
}
if (itemsize == sizeof(long)) {
return _add_clauses_from_array(self, array_length, (const long *) array_address);
}
if (itemsize == sizeof(long long)) {
return _add_clauses_from_array(self, array_length, (const long long *) array_address);
}
PyErr_Format(PyExc_ValueError, "invalid clause array: invalid itemsize '%ld'", itemsize);
return 0;
}

PyDoc_STRVAR(add_clauses_doc,
"add_clauses(clauses)\n\
Add iterable of clauses to the solver.\n\
\n\
:param clauses: List of clauses. Each clause contains literals (ints)\n\
Alternatively, this can be a flat array.array or other contiguous\n\
buffer (format 'i', 'l', or 'q') of zero separated and terminated\n\
clauses of literals (ints).\n\
:type clauses: <list> or <array.array>\n\
:return: None\n\
:rtype: <None>"
);

static PyObject* add_clauses(Counter *self, PyObject *args, PyObject *kwds)
{
static char const* kwlist[] = {"clauses", NULL};
PyObject *clauses;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", const_cast<char**>(kwlist), &clauses)) {
return NULL;
}

if (PyObject_CheckBuffer(clauses)) {
Py_buffer view;
memset(&view, 0, sizeof(view));
if (PyObject_GetBuffer(clauses, &view, PyBUF_CONTIG_RO | PyBUF_FORMAT) != 0) {
return NULL;
}

int ret = _add_clauses_from_buffer(self, &view);
PyBuffer_Release(&view);

if (ret == 0 || PyErr_Occurred()) {
return 0;
}
Py_INCREF(Py_None);
return Py_None;
}

PyObject *iterator = PyObject_GetIter(clauses);
if (iterator == NULL) {
PyErr_SetString(PyExc_TypeError, "iterable object expected");
return NULL;
}

PyObject *clause;
while ((clause = PyIter_Next(iterator)) != NULL) {
_add_clause(self, clause);
/* release reference when done */
Py_DECREF(clause);
}

/* release reference when done */
Py_DECREF(iterator);
if (PyErr_Occurred()) {
return NULL;
}

Py_INCREF(Py_None);
return Py_None;
}


static void get_cnf_from_arjun(Counter* self)
{
const uint32_t orig_num_vars = self->arjun->get_orig_num_vars();
Expand Down Expand Up @@ -357,6 +489,7 @@ static PyObject* count(Counter *self, PyObject *args, PyObject *kwds)
static PyMethodDef Counter_methods[] = {
{"count", (PyCFunction) count, METH_VARARGS | METH_KEYWORDS, count_doc},
{"add_clause",(PyCFunction) add_clause, METH_VARARGS | METH_KEYWORDS, add_clause_doc},
{"add_clauses", (PyCFunction) add_clauses, METH_VARARGS | METH_KEYWORDS, add_clauses_doc},
{NULL, NULL} // Sentinel
};

Expand Down

0 comments on commit 9a7e9d7

Please sign in to comment.