diff --git a/Include/abstract.h b/Include/abstract.h index b4c2bedef442bf..3b4ba2d10142de 100644 --- a/Include/abstract.h +++ b/Include/abstract.h @@ -397,6 +397,13 @@ PyAPI_FUNC(int) PyIter_Check(PyObject *); This function always succeeds. */ PyAPI_FUNC(int) PyAIter_Check(PyObject *); +/* Takes an iterator object and calls its tp_iternext slot, + setting *item to the next value, or to NULL if the iterator + is exhausted. + + returns 0 on success and -1 on error. */ +PyAPI_FUNC(int) PyIter_NextItem(PyObject *iter, PyObject **item); + /* Takes an iterator object and calls its tp_iternext slot, returning the next value. diff --git a/Lib/test/test_capi/test_misc.py b/Lib/test/test_capi/test_misc.py index 037c8112d53e7a..15155afb6d68d5 100644 --- a/Lib/test/test_capi/test_misc.py +++ b/Lib/test/test_capi/test_misc.py @@ -410,6 +410,41 @@ def __delitem__(self, index): _testcapi.sequence_del_slice(mapping, 1, 3) self.assertEqual(mapping, {1: 'a', 2: 'b', 3: 'c'}) + def run_iter_api_test(self, next_func): + inputs = [ (), (1,2,3), + [], [1,2,3]] + + for inp in inputs: + items = [] + it = iter(inp) + while (item := next_func(it)) is not None: + items.append(item) + self.assertEqual(items, list(inp)) + + class Broken: + def __init__(self): + self.count = 0 + + def __next__(self): + if self.count < 3: + self.count += 1 + return self.count + else: + raise TypeError('bad type') + + it = Broken() + self.assertEqual(next_func(it), 1) + self.assertEqual(next_func(it), 2) + self.assertEqual(next_func(it), 3) + with self.assertRaisesRegex(TypeError, 'bad type'): + next_func(it) + + def test_iter_next(self): + self.run_iter_api_test(_testcapi.call_pyiter_next) + + def test_iter_nextitem(self): + self.run_iter_api_test(_testcapi.call_pyiter_nextitem) + @unittest.skipUnless(hasattr(_testcapi, 'negative_refcount'), 'need _testcapi.negative_refcount') def test_negative_refcount(self): diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c index d7c89f48f792ed..a662053a99e41d 100644 --- a/Modules/_testcapimodule.c +++ b/Modules/_testcapimodule.c @@ -282,6 +282,41 @@ dict_getitem_knownhash(PyObject *self, PyObject *args) return Py_XNewRef(result); } +static PyObject* +call_pyiter_next(PyObject* self, PyObject *args) +{ + PyObject *iter; + if (!PyArg_ParseTuple(args, "O:call_pyiter_next", &iter)) { + return NULL; + } + assert(PyIter_Check(iter) || PyAIter_Check(iter)); + PyObject *item = PyIter_Next(iter); + if (item == NULL && !PyErr_Occurred()) { + Py_RETURN_NONE; + } + return item; +} + +static PyObject* +call_pyiter_nextitem(PyObject* self, PyObject *args) +{ + PyObject *iter; + if (!PyArg_ParseTuple(args, "O:call_pyiter_nextitem", &iter)) { + return NULL; + } + assert(PyIter_Check(iter) || PyAIter_Check(iter)); + PyObject *item = NULL; + int ret = PyIter_NextItem(iter, &item); + if (ret < 0) { + return NULL; + } + if (item == NULL && !PyErr_Occurred()) { + Py_RETURN_NONE; + } + return item; +} + + /* Issue #4701: Check that PyObject_Hash implicitly calls * PyType_Ready if it hasn't already been called */ @@ -3286,6 +3321,8 @@ static PyMethodDef TestMethods[] = { {"test_list_api", test_list_api, METH_NOARGS}, {"test_dict_iteration", test_dict_iteration, METH_NOARGS}, {"dict_getitem_knownhash", dict_getitem_knownhash, METH_VARARGS}, + {"call_pyiter_next", call_pyiter_next, METH_VARARGS}, + {"call_pyiter_nextitem", call_pyiter_nextitem, METH_VARARGS}, {"test_lazy_hash_inheritance", test_lazy_hash_inheritance,METH_NOARGS}, {"test_xincref_doesnt_leak",test_xincref_doesnt_leak, METH_NOARGS}, {"test_incref_doesnt_leak", test_incref_doesnt_leak, METH_NOARGS}, diff --git a/Modules/_xxtestfuzz/fuzzer.c b/Modules/_xxtestfuzz/fuzzer.c index 37d402824853f0..c7f8244c5d4d4f 100644 --- a/Modules/_xxtestfuzz/fuzzer.c +++ b/Modules/_xxtestfuzz/fuzzer.c @@ -377,7 +377,7 @@ static int fuzz_csv_reader(const char* data, size_t size) { if (reader) { /* Consume all of the reader as an iterator */ PyObject* parsed_line; - while ((parsed_line = PyIter_Next(reader))) { + while (PyIter_NextItem(reader, &parsed_line) == 0) { Py_DECREF(parsed_line); } } diff --git a/Objects/abstract.c b/Objects/abstract.c index e95785900c9c5f..3fdb5ed9a9b9e2 100644 --- a/Objects/abstract.c +++ b/Objects/abstract.c @@ -2833,6 +2833,29 @@ PyAIter_Check(PyObject *obj) tp->tp_as_async->am_anext != &_PyObject_NextNotImplemented); } +/* Set *item to the next item. Return 0 on success and -1 on error. + * If the iteration terminates normally, set *item to NULL and clear + * the PyExc_StopIteration exception (if it was set). + */ +int +PyIter_NextItem(PyObject *iter, PyObject **item) +{ + *item = (*Py_TYPE(iter)->tp_iternext)(iter); + if (*item == NULL) { + PyThreadState *tstate = _PyThreadState_GET(); + if (_PyErr_Occurred(tstate)) { + if (_PyErr_ExceptionMatches(tstate, PyExc_StopIteration)) { + _PyErr_Clear(tstate); + *item = NULL; + } + else { + return -1; + } + } + } + return 0; +} + /* Return next item. * If an error occurs, return NULL. PyErr_Occurred() will be true. * If the iteration terminates normally, return NULL and clear the @@ -2844,14 +2867,8 @@ PyObject * PyIter_Next(PyObject *iter) { PyObject *result; - result = (*Py_TYPE(iter)->tp_iternext)(iter); - if (result == NULL) { - PyThreadState *tstate = _PyThreadState_GET(); - if (_PyErr_Occurred(tstate) - && _PyErr_ExceptionMatches(tstate, PyExc_StopIteration)) - { - _PyErr_Clear(tstate); - } + if (PyIter_NextItem(iter, &result) < 0) { + return NULL; } return result; }