diff --git a/python/src/pyapproxmc.cpp b/python/src/pyapproxmc.cpp index 80a727d..b3bacd6 100644 --- a/python/src/pyapproxmc.cpp +++ b/python/src/pyapproxmc.cpp @@ -208,6 +208,138 @@ static PyObject* add_clause(Counter *self, PyObject *args, PyObject *kwds) } +template +static int _add_clauses_from_array(Counter *self, const size_t array_length, const T *array) +{ + if (array_length == 0) { + return 1; + } + if (array[array_length - 1] != 0) { + PyErr_SetString(PyExc_ValueError, "last clause not terminated by zero"); + return 0; + } + size_t k = 0; + long val = 0; + std::vector& lits = self->tmp_cl_lits; + for (val = (long) array[k]; k < array_length; val = (long) array[++k]) { + lits.clear(); + long int max_var = 0; + for (; k < array_length && val != 0; val = (long) array[++k]) { + long var; + bool sign; + if (val > std::numeric_limits::max()/2 + || val < std::numeric_limits::min()/2 + ) { + PyErr_Format(PyExc_ValueError, "integer %ld is too small or too large", val); + return 0; + } + + sign = (val < 0); + var = std::abs(val) - 1; + max_var = std::max(var, max_var); + + lits.push_back(CMSat::Lit(var, sign)); + } + if (!lits.empty()) { + if (max_var >= (long int)self->appmc->nVars()) { + self->appmc->new_vars(max_var-(long int)self->appmc->nVars()+1); + } + self->appmc->add_clause(lits); + } + } + return 1; +} + +static int _add_clauses_from_buffer(Counter *self, Py_buffer *view) +{ + if (view->ndim != 1) { + PyErr_Format(PyExc_ValueError, "invalid clause array: expected 1-D array, got %d-D", view->ndim); + return 0; + } + if (strcmp(view->format, "i") != 0 && strcmp(view->format, "l") != 0 && strcmp(view->format, "q") != 0) { + PyErr_Format(PyExc_ValueError, "invalid clause array: invalid format '%s'", view->format); + return 0; + } + + void * array_address = view->buf; + size_t itemsize = view->itemsize; + size_t array_length = view->len / itemsize; + + if (itemsize == sizeof(int)) { + return _add_clauses_from_array(self, array_length, (const int *) array_address); + } + if (itemsize == sizeof(long)) { + return _add_clauses_from_array(self, array_length, (const long *) array_address); + } + if (itemsize == sizeof(long long)) { + return _add_clauses_from_array(self, array_length, (const long long *) array_address); + } + PyErr_Format(PyExc_ValueError, "invalid clause array: invalid itemsize '%ld'", itemsize); + return 0; +} + +PyDoc_STRVAR(add_clauses_doc, +"add_clauses(clauses)\n\ +Add iterable of clauses to the solver.\n\ +\n\ +:param clauses: List of clauses. Each clause contains literals (ints)\n\ + Alternatively, this can be a flat array.array or other contiguous\n\ + buffer (format 'i', 'l', or 'q') of zero separated and terminated\n\ + clauses of literals (ints).\n\ +:type clauses: or \n\ +:return: None\n\ +:rtype: " +); + +static PyObject* add_clauses(Counter *self, PyObject *args, PyObject *kwds) +{ + static char const* kwlist[] = {"clauses", NULL}; + PyObject *clauses; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", const_cast(kwlist), &clauses)) { + return NULL; + } + + if (PyObject_CheckBuffer(clauses)) { + Py_buffer view; + memset(&view, 0, sizeof(view)); + if (PyObject_GetBuffer(clauses, &view, PyBUF_CONTIG_RO | PyBUF_FORMAT) != 0) { + return NULL; + } + + int ret = _add_clauses_from_buffer(self, &view); + PyBuffer_Release(&view); + + if (ret == 0 || PyErr_Occurred()) { + return 0; + } + Py_INCREF(Py_None); + return Py_None; + } + + PyObject *iterator = PyObject_GetIter(clauses); + if (iterator == NULL) { + PyErr_SetString(PyExc_TypeError, "iterable object expected"); + return NULL; + } + + PyObject *clause; + while ((clause = PyIter_Next(iterator)) != NULL) { + _add_clause(self, clause); + /* release reference when done */ + Py_DECREF(clause); + } + + /* release reference when done */ + Py_DECREF(iterator); + if (PyErr_Occurred()) { + return NULL; + } + + Py_INCREF(Py_None); + return Py_None; +} + + static void get_cnf_from_arjun(Counter* self) { const uint32_t orig_num_vars = self->arjun->get_orig_num_vars(); @@ -357,6 +489,7 @@ static PyObject* count(Counter *self, PyObject *args, PyObject *kwds) static PyMethodDef Counter_methods[] = { {"count", (PyCFunction) count, METH_VARARGS | METH_KEYWORDS, count_doc}, {"add_clause",(PyCFunction) add_clause, METH_VARARGS | METH_KEYWORDS, add_clause_doc}, + {"add_clauses", (PyCFunction) add_clauses, METH_VARARGS | METH_KEYWORDS, add_clauses_doc}, {NULL, NULL} // Sentinel };