diff --git a/Cargo.toml b/Cargo.toml index cae2692..216867b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ bytes = "1" futures-core = "0.3" http = "1" http-body = "1" -parking_lot = "0.12" +parking_lot = { version = "0.12", features = ["arc_lock", "send_guard"] } sqlx = { version = "0.7", default-features = false } thiserror = "1" tower-layer = "0.3" diff --git a/src/extension.rs b/src/extension.rs index 485a75b..196575d 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -1,31 +1,36 @@ +use std::sync::Arc; + +use parking_lot::{lock_api::ArcMutexGuard, Mutex, RawMutex}; use sqlx::Transaction; -use crate::{ - slot::{Lease, Slot}, - Error, Marker, State, -}; +use crate::{Error, Marker, State}; /// The request extension. pub(crate) struct Extension { - slot: Slot>, + slot: Arc>>, } impl Extension { pub(crate) fn new(state: State) -> Self { - let slot = Slot::new(LazyTransaction::new(state)); + let slot = Arc::new(Mutex::new(LazyTransaction::new(state))); Self { slot } } - pub(crate) async fn acquire(&self) -> Result>, Error> { - let mut tx = self.slot.lease().ok_or(Error::OverlappingExtractors)?; + pub(crate) async fn acquire( + &self, + ) -> Result>, Error> { + let mut tx = self + .slot + .try_lock_arc() + .ok_or(Error::OverlappingExtractors)?; tx.acquire().await?; Ok(tx) } pub(crate) async fn resolve(&self) -> Result<(), sqlx::Error> { - if let Some(tx) = self.slot.lease() { - tx.steal().resolve().await?; + if let Some(mut tx) = self.slot.try_lock_arc() { + tx.resolve().await?; } Ok(()) } @@ -49,6 +54,7 @@ enum LazyTransactionState { Acquired { tx: Transaction<'static, DB::Driver>, }, + Resolved, } impl LazyTransaction { @@ -62,6 +68,7 @@ impl LazyTransaction { panic!("BUG: exposed unacquired LazyTransaction") } LazyTransactionState::Acquired { tx } => tx, + LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"), } } @@ -71,10 +78,11 @@ impl LazyTransaction { panic!("BUG: exposed unacquired LazyTransaction") } LazyTransactionState::Acquired { tx } => tx, + LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"), } } - async fn acquire(&mut self) -> Result<(), sqlx::Error> { + async fn acquire(&mut self) -> Result<(), Error> { match &self.0 { LazyTransactionState::Unacquired { state } => { let tx = state.transaction().await?; @@ -82,22 +90,24 @@ impl LazyTransaction { Ok(()) } LazyTransactionState::Acquired { .. } => Ok(()), + LazyTransactionState::Resolved => Err(Error::OverlappingExtractors), } } - pub(crate) async fn resolve(self) -> Result<(), sqlx::Error> { - match self.0 { - LazyTransactionState::Unacquired { .. } => Ok(()), + pub(crate) async fn resolve(&mut self) -> Result<(), sqlx::Error> { + match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) { + LazyTransactionState::Unacquired { .. } | LazyTransactionState::Resolved => Ok(()), LazyTransactionState::Acquired { tx } => tx.commit().await, } } - pub(crate) async fn commit(self) -> Result<(), sqlx::Error> { - match self.0 { + pub(crate) async fn commit(&mut self) -> Result<(), sqlx::Error> { + match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) { LazyTransactionState::Unacquired { .. } => { panic!("BUG: tried to commit unacquired transaction") } LazyTransactionState::Acquired { tx } => tx.commit().await, + LazyTransactionState::Resolved => panic!("BUG: tried to commit resolved transaction"), } } } diff --git a/src/lib.rs b/src/lib.rs index 68eb423..9621884 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -88,7 +88,6 @@ mod error; mod extension; mod layer; mod marker; -mod slot; mod state; mod tx; diff --git a/src/slot.rs b/src/slot.rs deleted file mode 100644 index cc1f392..0000000 --- a/src/slot.rs +++ /dev/null @@ -1,208 +0,0 @@ -//! An API for transferring ownership of a resource. -//! -//! The `Slot` and `Lease` types implement an API for sharing access to a resource `T`. -//! Conceptually, the `Slot` is the "primary" owner of the value, and access can be leased to one -//! other owner through an associated `Lease`. -//! -//! It's implemented as a wrapper around an `Arc>>`, where the `Slot` `take`s the -//! value from the `Option` on lease, and the `Lease` puts it back in on drop. -//! -//! Note that while this is **safe** to use across threads (it is `Send` + `Sync`), concurrently -//! `lease`ing and dropping `Lease`s is not supported. It's intended for use in synchronous -//! scenarios that the compiler thinks may be concurrent, like middleware stacks! - -use parking_lot::Mutex; -use std::sync::{Arc, Weak}; - -/// A slot that may contain a value that can be leased. -/// -/// The slot itself is opaque, the only way to see the value (if any) is with `lease` or -/// `into_inner`. -pub(crate) struct Slot(Arc>>); - -impl Slot { - /// Construct a new `Slot` holding the given value. - pub(crate) fn new(value: T) -> Self { - Self(Arc::new(Mutex::new(Some(value)))) - } - - /// Lease the value from the slot, leaving it empty. - /// - /// Ownership of the contained value moves to the `Lease` for the duration. The value may return - /// to the slot when the `Lease` is dropped, or the value may be "stolen", leaving the slot - /// permanently empty. - pub(crate) fn lease(&self) -> Option> { - if let Some(value) = self.0.try_lock().and_then(|mut slot| slot.take()) { - Some(Lease::new(value, Arc::downgrade(&self.0))) - } else { - None - } - } -} - -impl Clone for Slot { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -/// A lease of a value from a `Slot`. -#[derive(Debug)] -pub(crate) struct Lease(lease::State); - -impl Lease { - fn new(value: T, slot: Weak>>) -> Self { - Self(lease::State::new(value, slot)) - } - - /// Steal the value, meaning it will never return to the slot. - pub(crate) fn steal(mut self) -> T { - self.0.steal() - } -} - -impl Drop for Lease { - fn drop(&mut self) { - self.0.drop() - } -} - -impl AsRef for Lease { - fn as_ref(&self) -> &T { - self.0.as_ref() - } -} - -impl AsMut for Lease { - fn as_mut(&mut self) -> &mut T { - self.0.as_mut() - } -} - -impl std::ops::Deref for Lease { - type Target = T; - - fn deref(&self) -> &Self::Target { - self.0.as_ref() - } -} - -impl std::ops::DerefMut for Lease { - fn deref_mut(&mut self) -> &mut Self::Target { - self.0.as_mut() - } -} - -mod lease { - use std::sync::Weak; - - use parking_lot::Mutex; - - #[derive(Debug)] - pub(super) struct State(Inner); - - #[derive(Debug)] - enum Inner { - Dropped, - Stolen, - Live { - value: T, - slot: Weak>>, - }, - } - - impl State { - pub(super) fn new(value: T, slot: Weak>>) -> Self { - Self(Inner::Live { value, slot }) - } - - pub(super) fn as_ref(&self) -> &T { - match &self.0 { - Inner::Dropped | Inner::Stolen => panic!("BUG: LeaseState used after drop/steal"), - Inner::Live { value, .. } => value, - } - } - - pub(super) fn as_mut(&mut self) -> &mut T { - match &mut self.0 { - Inner::Dropped | Inner::Stolen => panic!("BUG: LeaseState used after drop/steal"), - Inner::Live { value, .. } => value, - } - } - - pub(super) fn drop(&mut self) { - match std::mem::replace(&mut self.0, Inner::Dropped) { - Inner::Dropped => panic!("BUG: LeaseState::drop called twice"), - Inner::Stolen => {} // nothing to do if the value was stolen - Inner::Live { value, slot } => { - // try to return value to the slot, if it fails just drop value - if let Some(slot) = slot.upgrade() { - if let Some(mut slot) = slot.try_lock() { - assert!(slot.is_none(), "BUG: slot repopulated during lease"); - *slot = Some(value); - } - } - } - } - } - - pub(super) fn steal(&mut self) -> T { - match std::mem::replace(&mut self.0, Inner::Stolen) { - Inner::Dropped => panic!("BUG: LeaseState::steal called after drop"), - Inner::Stolen => panic!("BUG: LeaseState::steal called twice"), - Inner::Live { value, .. } => value, - } - } - } -} - -#[cfg(test)] -mod tests { - use super::Slot; - - #[test] - fn lease_and_return() { - // Create a slot containing a resource. - let slot = Slot::new("Hello".to_string()); - - // Lease the resource, taking it from the slot. - let mut lease = slot.lease().unwrap(); - - std::thread::spawn(move || { - // We have exclusive access to the resource through the lease, which implements `Deref[Mut]` - lease.push_str(", world!"); - - // By default the value is returned to the slot on drop - }) - .join() - .unwrap(); - - // The value is now back in the slot - assert_eq!( - slot.lease().as_deref().map(|s| s.as_str()), - Some("Hello, world!") - ); - - // We can also take ownership of the value in the slot (if any) - assert_eq!( - slot.lease().map(|lease| lease.steal()), - Some("Hello, world!".to_string()) - ); - } - - #[test] - fn lease_and_steal() { - let slot = Slot::new("Hello".to_string()); - - let lease = slot.lease().unwrap(); - std::thread::spawn(move || { - // We can steal ownership of the resource, leaving the slot permanently empty - let _: String = lease.steal(); - }) - .join() - .unwrap(); - - // The slot is now permanently empty - assert!(slot.lease().is_none()); - } -} diff --git a/src/tx.rs b/src/tx.rs index 1ddf963..e4a1cfa 100644 --- a/src/tx.rs +++ b/src/tx.rs @@ -8,10 +8,10 @@ use axum_core::{ }; use futures_core::{future::BoxFuture, stream::BoxStream}; use http::request::Parts; +use parking_lot::{lock_api::ArcMutexGuard, RawMutex}; use crate::{ extension::{Extension, LazyTransaction}, - slot::Lease, Config, Error, Marker, State, }; @@ -74,7 +74,7 @@ use crate::{ /// } /// ``` pub struct Tx { - tx: Lease>, + tx: ArcMutexGuard>, _error: PhantomData, } @@ -121,8 +121,8 @@ impl Tx { /// /// **Note:** trying to use the `Tx` extractor again after calling `commit` will currently /// generate [`Error::OverlappingExtractors`] errors. This may change in future. - pub async fn commit(self) -> Result<(), sqlx::Error> { - self.tx.steal().commit().await + pub async fn commit(mut self) -> Result<(), sqlx::Error> { + self.tx.commit().await } } @@ -134,13 +134,13 @@ impl fmt::Debug for Tx { impl AsRef> for Tx { fn as_ref(&self) -> &sqlx::Transaction<'static, DB::Driver> { - self.tx.as_ref().as_ref() + self.tx.as_ref() } } impl AsMut> for Tx { fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB::Driver> { - self.tx.as_mut().as_mut() + self.tx.as_mut() } } @@ -148,13 +148,13 @@ impl std::ops::Deref for Tx { type Target = sqlx::Transaction<'static, DB::Driver>; fn deref(&self) -> &Self::Target { - self.tx.as_ref().as_ref() + self.tx.as_ref() } } impl std::ops::DerefMut for Tx { fn deref_mut(&mut self) -> &mut Self::Target { - self.tx.as_mut().as_mut() + self.tx.as_mut() } }