Skip to content

Commit

Permalink
pythongh-105201: Add PyIter_NextItem to replace PyIter_Next which has…
Browse files Browse the repository at this point in the history
… an ambiguous return value
  • Loading branch information
iritkatriel committed Jun 1, 2023
1 parent a241003 commit 5e91524
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 9 deletions.
7 changes: 7 additions & 0 deletions Include/abstract.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,13 @@ PyAPI_FUNC(int) PyIter_Check(PyObject *);
This function always succeeds. */
PyAPI_FUNC(int) PyAIter_Check(PyObject *);

/* Takes an iterator object and calls its tp_iternext slot,
setting *item to the next value, or to NULL if the iterator
is exhausted.
returns 0 on success and -1 on error. */
PyAPI_FUNC(int) PyIter_NextItem(PyObject *iter, PyObject **item);

/* Takes an iterator object and calls its tp_iternext slot,
returning the next value.
Expand Down
35 changes: 35 additions & 0 deletions Lib/test/test_capi/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,41 @@ def __delitem__(self, index):
_testcapi.sequence_del_slice(mapping, 1, 3)
self.assertEqual(mapping, {1: 'a', 2: 'b', 3: 'c'})

def run_iter_api_test(self, next_func):
inputs = [ (), (1,2,3),
[], [1,2,3]]

for inp in inputs:
items = []
it = iter(inp)
while (item := next_func(it)) is not None:
items.append(item)
self.assertEqual(items, list(inp))

class Broken:
def __init__(self):
self.count = 0

def __next__(self):
if self.count < 3:
self.count += 1
return self.count
else:
raise TypeError('bad type')

it = Broken()
self.assertEqual(next_func(it), 1)
self.assertEqual(next_func(it), 2)
self.assertEqual(next_func(it), 3)
with self.assertRaisesRegex(TypeError, 'bad type'):
next_func(it)

def test_iter_next(self):
self.run_iter_api_test(_testcapi.call_pyiter_next)

def test_iter_nextitem(self):
self.run_iter_api_test(_testcapi.call_pyiter_nextitem)

@unittest.skipUnless(hasattr(_testcapi, 'negative_refcount'),
'need _testcapi.negative_refcount')
def test_negative_refcount(self):
Expand Down
37 changes: 37 additions & 0 deletions Modules/_testcapimodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,41 @@ dict_getitem_knownhash(PyObject *self, PyObject *args)
return Py_XNewRef(result);
}

static PyObject*
call_pyiter_next(PyObject* self, PyObject *args)
{
PyObject *iter;
if (!PyArg_ParseTuple(args, "O:call_pyiter_next", &iter)) {
return NULL;
}
assert(PyIter_Check(iter) || PyAIter_Check(iter));
PyObject *item = PyIter_Next(iter);
if (item == NULL && !PyErr_Occurred()) {
Py_RETURN_NONE;
}
return item;
}

static PyObject*
call_pyiter_nextitem(PyObject* self, PyObject *args)
{
PyObject *iter;
if (!PyArg_ParseTuple(args, "O:call_pyiter_nextitem", &iter)) {
return NULL;
}
assert(PyIter_Check(iter) || PyAIter_Check(iter));
PyObject *item = NULL;
int ret = PyIter_NextItem(iter, &item);
if (ret < 0) {
return NULL;
}
if (item == NULL && !PyErr_Occurred()) {
Py_RETURN_NONE;
}
return item;
}


/* Issue #4701: Check that PyObject_Hash implicitly calls
* PyType_Ready if it hasn't already been called
*/
Expand Down Expand Up @@ -3286,6 +3321,8 @@ static PyMethodDef TestMethods[] = {
{"test_list_api", test_list_api, METH_NOARGS},
{"test_dict_iteration", test_dict_iteration, METH_NOARGS},
{"dict_getitem_knownhash", dict_getitem_knownhash, METH_VARARGS},
{"call_pyiter_next", call_pyiter_next, METH_VARARGS},
{"call_pyiter_nextitem", call_pyiter_nextitem, METH_VARARGS},
{"test_lazy_hash_inheritance", test_lazy_hash_inheritance,METH_NOARGS},
{"test_xincref_doesnt_leak",test_xincref_doesnt_leak, METH_NOARGS},
{"test_incref_doesnt_leak", test_incref_doesnt_leak, METH_NOARGS},
Expand Down
2 changes: 1 addition & 1 deletion Modules/_xxtestfuzz/fuzzer.c
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ static int fuzz_csv_reader(const char* data, size_t size) {
if (reader) {
/* Consume all of the reader as an iterator */
PyObject* parsed_line;
while ((parsed_line = PyIter_Next(reader))) {
while (PyIter_NextItem(reader, &parsed_line) == 0) {
Py_DECREF(parsed_line);
}
}
Expand Down
33 changes: 25 additions & 8 deletions Objects/abstract.c
Original file line number Diff line number Diff line change
Expand Up @@ -2833,6 +2833,29 @@ PyAIter_Check(PyObject *obj)
tp->tp_as_async->am_anext != &_PyObject_NextNotImplemented);
}

/* Set *item to the next item. Return 0 on success and -1 on error.
* If the iteration terminates normally, set *item to NULL and clear
* the PyExc_StopIteration exception (if it was set).
*/
int
PyIter_NextItem(PyObject *iter, PyObject **item)
{
*item = (*Py_TYPE(iter)->tp_iternext)(iter);
if (*item == NULL) {
PyThreadState *tstate = _PyThreadState_GET();
if (_PyErr_Occurred(tstate)) {
if (_PyErr_ExceptionMatches(tstate, PyExc_StopIteration)) {
_PyErr_Clear(tstate);
*item = NULL;
}
else {
return -1;
}
}
}
return 0;
}

/* Return next item.
* If an error occurs, return NULL. PyErr_Occurred() will be true.
* If the iteration terminates normally, return NULL and clear the
Expand All @@ -2844,14 +2867,8 @@ PyObject *
PyIter_Next(PyObject *iter)
{
PyObject *result;
result = (*Py_TYPE(iter)->tp_iternext)(iter);
if (result == NULL) {
PyThreadState *tstate = _PyThreadState_GET();
if (_PyErr_Occurred(tstate)
&& _PyErr_ExceptionMatches(tstate, PyExc_StopIteration))
{
_PyErr_Clear(tstate);
}
if (PyIter_NextItem(iter, &result) < 0) {
return NULL;
}
return result;
}
Expand Down

0 comments on commit 5e91524

Please sign in to comment.