Skip to content

Commit

Permalink
refactor: ditch slot
Browse files Browse the repository at this point in the history
It turns out the `parking_lot` `arc_lock` future does the main thing we
want from `Slot`, giving us `'static` lock guards that unlock the value
on drop. The only missing functionality is the "stealing" we need to
obtain ownership in order to commit the transaction. Rather than
implementing this via `Option`, we add an additional state to
`LazyTransaction` and handle it there.

Ultimately this removes a lot of code, and makes the synchronisation
mechanism even less exotic.
  • Loading branch information
connec committed Dec 23, 2023
1 parent 84a49ae commit 1cec0ae
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 234 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ bytes = "1"
futures-core = "0.3"
http = "1"
http-body = "1"
parking_lot = "0.12"
parking_lot = { version = "0.12", features = ["arc_lock", "send_guard"] }
sqlx = { version = "0.7", default-features = false }
thiserror = "1"
tower-layer = "0.3"
Expand Down
42 changes: 26 additions & 16 deletions src/extension.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,36 @@
use std::sync::Arc;

use parking_lot::{lock_api::ArcMutexGuard, Mutex, RawMutex};
use sqlx::Transaction;

use crate::{
slot::{Lease, Slot},
Error, Marker, State,
};
use crate::{Error, Marker, State};

/// The request extension.
pub(crate) struct Extension<DB: Marker> {
slot: Slot<LazyTransaction<DB>>,
slot: Arc<Mutex<LazyTransaction<DB>>>,
}

impl<DB: Marker> Extension<DB> {
pub(crate) fn new(state: State<DB>) -> Self {
let slot = Slot::new(LazyTransaction::new(state));
let slot = Arc::new(Mutex::new(LazyTransaction::new(state)));
Self { slot }
}

pub(crate) async fn acquire(&self) -> Result<Lease<LazyTransaction<DB>>, Error> {
let mut tx = self.slot.lease().ok_or(Error::OverlappingExtractors)?;
pub(crate) async fn acquire(
&self,
) -> Result<ArcMutexGuard<RawMutex, LazyTransaction<DB>>, Error> {
let mut tx = self
.slot
.try_lock_arc()
.ok_or(Error::OverlappingExtractors)?;
tx.acquire().await?;

Ok(tx)
}

pub(crate) async fn resolve(&self) -> Result<(), sqlx::Error> {
if let Some(tx) = self.slot.lease() {
tx.steal().resolve().await?;
if let Some(mut tx) = self.slot.try_lock_arc() {
tx.resolve().await?;
}
Ok(())
}
Expand All @@ -49,6 +54,7 @@ enum LazyTransactionState<DB: Marker> {
Acquired {
tx: Transaction<'static, DB::Driver>,
},
Resolved,
}

impl<DB: Marker> LazyTransaction<DB> {
Expand All @@ -62,6 +68,7 @@ impl<DB: Marker> LazyTransaction<DB> {
panic!("BUG: exposed unacquired LazyTransaction")
}
LazyTransactionState::Acquired { tx } => tx,
LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"),
}
}

Expand All @@ -71,33 +78,36 @@ impl<DB: Marker> LazyTransaction<DB> {
panic!("BUG: exposed unacquired LazyTransaction")
}
LazyTransactionState::Acquired { tx } => tx,
LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"),
}
}

async fn acquire(&mut self) -> Result<(), sqlx::Error> {
async fn acquire(&mut self) -> Result<(), Error> {
match &self.0 {
LazyTransactionState::Unacquired { state } => {
let tx = state.transaction().await?;
self.0 = LazyTransactionState::Acquired { tx };
Ok(())
}
LazyTransactionState::Acquired { .. } => Ok(()),
LazyTransactionState::Resolved => Err(Error::OverlappingExtractors),
}
}

pub(crate) async fn resolve(self) -> Result<(), sqlx::Error> {
match self.0 {
LazyTransactionState::Unacquired { .. } => Ok(()),
pub(crate) async fn resolve(&mut self) -> Result<(), sqlx::Error> {
match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) {
LazyTransactionState::Unacquired { .. } | LazyTransactionState::Resolved => Ok(()),
LazyTransactionState::Acquired { tx } => tx.commit().await,
}
}

pub(crate) async fn commit(self) -> Result<(), sqlx::Error> {
match self.0 {
pub(crate) async fn commit(&mut self) -> Result<(), sqlx::Error> {
match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) {
LazyTransactionState::Unacquired { .. } => {
panic!("BUG: tried to commit unacquired transaction")
}
LazyTransactionState::Acquired { tx } => tx.commit().await,
LazyTransactionState::Resolved => panic!("BUG: tried to commit resolved transaction"),
}
}
}
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ mod error;
mod extension;
mod layer;
mod marker;
mod slot;
mod state;
mod tx;

Expand Down
208 changes: 0 additions & 208 deletions src/slot.rs

This file was deleted.

16 changes: 8 additions & 8 deletions src/tx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use axum_core::{
};
use futures_core::{future::BoxFuture, stream::BoxStream};
use http::request::Parts;
use parking_lot::{lock_api::ArcMutexGuard, RawMutex};

use crate::{
extension::{Extension, LazyTransaction},
slot::Lease,
Config, Error, Marker, State,
};

Expand Down Expand Up @@ -74,7 +74,7 @@ use crate::{
/// }
/// ```
pub struct Tx<DB: Marker, E = Error> {
tx: Lease<LazyTransaction<DB>>,
tx: ArcMutexGuard<RawMutex, LazyTransaction<DB>>,
_error: PhantomData<E>,
}

Expand Down Expand Up @@ -121,8 +121,8 @@ impl<DB: Marker, E> Tx<DB, E> {
///
/// **Note:** trying to use the `Tx` extractor again after calling `commit` will currently
/// generate [`Error::OverlappingExtractors`] errors. This may change in future.
pub async fn commit(self) -> Result<(), sqlx::Error> {
self.tx.steal().commit().await
pub async fn commit(mut self) -> Result<(), sqlx::Error> {
self.tx.commit().await
}
}

Expand All @@ -134,27 +134,27 @@ impl<DB: Marker, E> fmt::Debug for Tx<DB, E> {

impl<DB: Marker, E> AsRef<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E> {
fn as_ref(&self) -> &sqlx::Transaction<'static, DB::Driver> {
self.tx.as_ref().as_ref()
self.tx.as_ref()
}
}

impl<DB: Marker, E> AsMut<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E> {
fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB::Driver> {
self.tx.as_mut().as_mut()
self.tx.as_mut()
}
}

impl<DB: Marker, E> std::ops::Deref for Tx<DB, E> {
type Target = sqlx::Transaction<'static, DB::Driver>;

fn deref(&self) -> &Self::Target {
self.tx.as_ref().as_ref()
self.tx.as_ref()
}
}

impl<DB: Marker, E> std::ops::DerefMut for Tx<DB, E> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.tx.as_mut().as_mut()
self.tx.as_mut()
}
}

Expand Down

0 comments on commit 1cec0ae

Please sign in to comment.