Skip to content

Commit

Permalink
Use plain dict iterator when not on free-threaded & not a subclass
Browse files Browse the repository at this point in the history
  • Loading branch information
bschoenmaeckers committed Aug 29, 2024
1 parent 52c242d commit 0b8f4c6
Showing 1 changed file with 92 additions and 18 deletions.
110 changes: 92 additions & 18 deletions src/types/dict.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use super::{PyIterator, PyMapping, PyTuple};
use crate::err::{self, PyErr, PyResult};
use crate::ffi::Py_ssize_t;
use crate::ffi_ptr_ext::FfiPtrExt;
use crate::instance::{Borrowed, Bound};
use crate::py_result_ext::PyResultExt;
use crate::types::any::PyAnyMethods;
use crate::types::{PyAny, PyList};
use crate::types::{PyAny, PyAnyMethods, PyIterator, PyList, PyMapping, PyTuple, PyTupleMethods};
use crate::{ffi, Python, ToPyObject};

/// Represents a Python `dict`.
Expand Down Expand Up @@ -370,23 +368,86 @@ fn dict_len(dict: &Bound<'_, PyDict>) -> Py_ssize_t {
}

/// PyO3 implementation of an iterator for a Python `dict` object.
pub struct BoundDictIterator<'py> {
iter: Bound<'py, PyIterator>,
remaining: usize,
pub enum BoundDictIterator<'py> {
ItemIter {
iter: Bound<'py, PyIterator>,
remaining: ffi::Py_ssize_t,
},
#[cfg(not(Py_GIL_DISABLED))]
DictIter {
dict: Bound<'py, PyDict>,
ppos: ffi::Py_ssize_t,
di_used: ffi::Py_ssize_t,
remaining: ffi::Py_ssize_t,
},
}

impl<'py> Iterator for BoundDictIterator<'py> {
type Item = (Bound<'py, PyAny>, Bound<'py, PyAny>);

#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.remaining = self.remaining.saturating_sub(1);
self.iter.next().map(Result::unwrap).map(|pair| {
let tuple = pair.downcast::<PyTuple>().unwrap();
let key = tuple.get_item(0).unwrap();
let value = tuple.get_item(1).unwrap();
(key, value)
})
match self {
BoundDictIterator::ItemIter { iter, remaining } => {
*remaining = remaining.saturating_sub(1);
iter.next().map(Result::unwrap).map(|tuple| {
let tuple = tuple.downcast::<PyTuple>().unwrap();
let key = tuple.get_item(0).unwrap();
let value = tuple.get_item(1).unwrap();
(key, value)
})
}
#[cfg(not(Py_GIL_DISABLED))]
BoundDictIterator::DictIter {
dict,
ppos,
di_used,
remaining,
} => {
let ma_used = dict_len(dict);

// These checks are similar to what CPython does.
//
// If the dimension of the dict changes e.g. key-value pairs are removed
// or added during iteration, this will panic next time when `next` is called
if *di_used != ma_used {
*di_used = -1;
panic!("dictionary changed size during iteration");
};

// If the dict is changed in such a way that the length remains constant
// then this will panic at the end of iteration - similar to this:
//
// d = {"a":1, "b":2, "c": 3}
//
// for k, v in d.items():
// d[f"{k}_"] = 4
// del d[k]
// print(k)
//
if *remaining == -1 {
*di_used = -1;
panic!("dictionary keys changed during iteration");
};

let mut key: *mut ffi::PyObject = std::ptr::null_mut();
let mut value: *mut ffi::PyObject = std::ptr::null_mut();

if unsafe { ffi::PyDict_Next(dict.as_ptr(), ppos, &mut key, &mut value) } != 0 {
*remaining -= 1;
let py = dict.py();
// Safety:
// - PyDict_Next returns borrowed values
// - we have already checked that `PyDict_Next` succeeded, so we can assume these to be non-null
Some((
unsafe { key.assume_borrowed_unchecked(py) }.to_owned(),
unsafe { value.assume_borrowed_unchecked(py) }.to_owned(),
))
} else {
None
}
}
}
}

#[inline]
Expand All @@ -398,17 +459,31 @@ impl<'py> Iterator for BoundDictIterator<'py> {

impl<'py> ExactSizeIterator for BoundDictIterator<'py> {
fn len(&self) -> usize {
self.remaining
match self {
BoundDictIterator::ItemIter { remaining, .. } => *remaining as usize,
#[cfg(not(Py_GIL_DISABLED))]
BoundDictIterator::DictIter { remaining, .. } => *remaining as usize,
}
}
}

impl<'py> BoundDictIterator<'py> {
fn new(dict: Bound<'py, PyDict>) -> Self {
let remaining = dict_len(&dict) as usize;
let remaining = dict_len(&dict);

#[cfg(not(Py_GIL_DISABLED))]
if dict.is_exact_instance_of::<PyDict>() {
return BoundDictIterator::DictIter {
dict,
ppos: 0,
di_used: remaining,
remaining,
};
};

let items = dict.call_method0(intern!(dict.py(), "items")).unwrap();
let iter = PyIterator::from_object(&items).unwrap();

BoundDictIterator { iter, remaining }
BoundDictIterator::ItemIter { iter, remaining }
}
}

Expand Down Expand Up @@ -481,7 +556,6 @@ mod borrowed_iter {
impl<'a, 'py> BorrowedDictIter<'a, 'py> {
pub(super) fn new(dict: Borrowed<'a, 'py, PyDict>) -> Self {
let len = dict_len(&dict);

BorrowedDictIter { dict, ppos: 0, len }
}
}
Expand Down

0 comments on commit 0b8f4c6

Please sign in to comment.