Skip to content

Commit

Permalink
make PyErrState thread-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Oct 30, 2024
1 parent f74d374 commit 08beaa5
Showing 1 changed file with 52 additions and 25 deletions.
77 changes: 52 additions & 25 deletions src/err/err_state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::cell::UnsafeCell;
use std::{
cell::UnsafeCell,
sync::{Mutex, Once},
thread::ThreadId,
};

use crate::{
exceptions::{PyBaseException, PyTypeError},
Expand All @@ -11,13 +15,14 @@ use crate::{
pub(crate) struct PyErrState {
// Safety: can only hand out references when in the "normalized" state. Will never change
// after normalization.
//
// The state is temporarily removed from the PyErr during normalization, to avoid
// concurrent modifications.
normalized: Once,
// Guard against re-entrancy when normalizing the exception state.
normalizing_thread: Mutex<Option<ThreadId>>,
inner: UnsafeCell<Option<PyErrStateInner>>,
}

// The inner value is only accessed through ways that require the gil is held.
// Safety: The inner value is protected by locking to ensure that only the normalized state is
// handed out as a reference.
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}

Expand Down Expand Up @@ -48,17 +53,22 @@ impl PyErrState {

fn from_inner(inner: PyErrStateInner) -> Self {
Self {
normalized: Once::new(),
normalizing_thread: Mutex::new(None),
inner: UnsafeCell::new(Some(inner)),
}
}

#[inline]
pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
if let Some(PyErrStateInner::Normalized(n)) = unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
return n;
if self.normalized.is_completed() {
match unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => return n,
_ => unreachable!(),
}
}

self.make_normalized(py)
Expand All @@ -69,25 +79,42 @@ impl PyErrState {
// This process is safe because:
// - Access is guaranteed not to be concurrent thanks to `Python` GIL token
// - Write happens only once, and then never will change again.
// - State is set to None during the normalization process, so that a second
// concurrent normalization attempt will panic before changing anything.

// FIXME: this needs to be rewritten to deal with free-threaded Python
// see https://github.com/PyO3/pyo3/issues/4584
// Guard against re-entrant normalization, because `Once` does not provide
// re-entrancy guarantees.
if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
if *thread == std::thread::current().id() {
panic!("Re-entrant normalization of PyErrState detected");
}
}

let state = unsafe {
(*self.inner.get())
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};
self.normalized.call_once(|| {
self.normalizing_thread
.lock()
.unwrap()
.replace(std::thread::current().id());

// Safety: no other thread can access the inner value while we are normalizing it.
let state = unsafe {
(*self.inner.get())
.take()
.expect("Cannot normalize a PyErr while already normalizing it.")
};

unsafe {
let self_state = &mut *self.inner.get();
*self_state = Some(PyErrStateInner::Normalized(state.normalize(py)));
match self_state {
Some(PyErrStateInner::Normalized(n)) => n,
_ => unreachable!(),
let normalized_state = PyErrStateInner::Normalized(state.normalize(py));

// Safety: no other thread can access the inner value while we are normalizing it.
unsafe {
*self.inner.get() = Some(normalized_state);
}
});

match unsafe {
// Safety: self.inner will never be written again once normalized.
&*self.inner.get()
} {
Some(PyErrStateInner::Normalized(n)) => return n,
_ => unreachable!(),
}
}
}
Expand Down

0 comments on commit 08beaa5

Please sign in to comment.