diff --git a/Cargo.toml b/Cargo.toml index fb61624..216867b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,12 +26,12 @@ runtime-tokio-rustls = ["sqlx/runtime-tokio-rustls"] features = ["all-databases", "runtime-tokio-rustls"] [dependencies] -axum-core = "0.3" +axum-core = "0.4" bytes = "1" futures-core = "0.3" -http = "0.2" -http-body = "0.4" -parking_lot = "0.12" +http = "1" +http-body = "1" +parking_lot = { version = "0.12", features = ["arc_lock", "send_guard"] } sqlx = { version = "0.7", default-features = false } thiserror = "1" tower-layer = "0.3" @@ -39,7 +39,8 @@ tower-service = "0.3" [dev-dependencies] axum-sqlx-tx = { path = ".", features = ["runtime-tokio-rustls", "sqlite"] } -axum = "0.6.4" -hyper = "0.14.17" +axum = "0.7.1" +http-body-util = "0.1.0" +hyper = "1.0.1" tokio = { version = "1.17.0", features = ["macros", "rt-multi-thread"] } tower = "0.4.12" diff --git a/examples/example.rs b/examples/example.rs index 4c09e05..0be4119 100644 --- a/examples/example.rs +++ b/examples/example.rs @@ -28,10 +28,10 @@ async fn main() -> Result<(), Box> { // Add the Tx state .with_state(state); - let server = axum::Server::bind(&([0, 0, 0, 0], 0).into()).serve(app.into_make_service()); + let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); + println!("Listening on {}", listener.local_addr().unwrap()); - println!("Listening on {}", server.local_addr()); - server.await?; + axum::serve(listener, app).await?; Ok(()) } diff --git a/src/error.rs b/src/error.rs index 014ea2c..e42fe34 100644 --- a/src/error.rs +++ b/src/error.rs @@ -70,7 +70,7 @@ /// # .route("/", post(create_user)) /// # .layer(layer) /// # .with_state(state); -/// # axum::Server::bind(todo!()).serve(app.into_make_service()); +/// # axum::serve(todo!(), app); /// # } /// # async fn create_user(mut tx: Tx, /* ... */) { /// # /* ... */ diff --git a/src/extension.rs b/src/extension.rs new file mode 100644 index 0000000..196575d --- /dev/null +++ b/src/extension.rs @@ -0,0 +1,113 @@ +use std::sync::Arc; + +use parking_lot::{lock_api::ArcMutexGuard, Mutex, RawMutex}; +use sqlx::Transaction; + +use crate::{Error, Marker, State}; + +/// The request extension. +pub(crate) struct Extension { + slot: Arc>>, +} + +impl Extension { + pub(crate) fn new(state: State) -> Self { + let slot = Arc::new(Mutex::new(LazyTransaction::new(state))); + Self { slot } + } + + pub(crate) async fn acquire( + &self, + ) -> Result>, Error> { + let mut tx = self + .slot + .try_lock_arc() + .ok_or(Error::OverlappingExtractors)?; + tx.acquire().await?; + + Ok(tx) + } + + pub(crate) async fn resolve(&self) -> Result<(), sqlx::Error> { + if let Some(mut tx) = self.slot.try_lock_arc() { + tx.resolve().await?; + } + Ok(()) + } +} + +impl Clone for Extension { + fn clone(&self) -> Self { + Self { + slot: self.slot.clone(), + } + } +} + +/// The lazy transaction. +pub(crate) struct LazyTransaction(LazyTransactionState); + +enum LazyTransactionState { + Unacquired { + state: State, + }, + Acquired { + tx: Transaction<'static, DB::Driver>, + }, + Resolved, +} + +impl LazyTransaction { + fn new(state: State) -> Self { + Self(LazyTransactionState::Unacquired { state }) + } + + pub(crate) fn as_ref(&self) -> &Transaction<'static, DB::Driver> { + match &self.0 { + LazyTransactionState::Unacquired { .. } => { + panic!("BUG: exposed unacquired LazyTransaction") + } + LazyTransactionState::Acquired { tx } => tx, + LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"), + } + } + + pub(crate) fn as_mut(&mut self) -> &mut Transaction<'static, DB::Driver> { + match &mut self.0 { + LazyTransactionState::Unacquired { .. } => { + panic!("BUG: exposed unacquired LazyTransaction") + } + LazyTransactionState::Acquired { tx } => tx, + LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"), + } + } + + async fn acquire(&mut self) -> Result<(), Error> { + match &self.0 { + LazyTransactionState::Unacquired { state } => { + let tx = state.transaction().await?; + self.0 = LazyTransactionState::Acquired { tx }; + Ok(()) + } + LazyTransactionState::Acquired { .. } => Ok(()), + LazyTransactionState::Resolved => Err(Error::OverlappingExtractors), + } + } + + pub(crate) async fn resolve(&mut self) -> Result<(), sqlx::Error> { + match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) { + LazyTransactionState::Unacquired { .. } | LazyTransactionState::Resolved => Ok(()), + LazyTransactionState::Acquired { tx } => tx.commit().await, + } + } + + pub(crate) async fn commit(&mut self) -> Result<(), sqlx::Error> { + match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) { + LazyTransactionState::Unacquired { .. } => { + panic!("BUG: tried to commit unacquired transaction") + } + LazyTransactionState::Acquired { tx } => tx.commit().await, + LazyTransactionState::Resolved => panic!("BUG: tried to commit resolved transaction"), + } + } +} diff --git a/src/layer.rs b/src/layer.rs index b5b4527..931d20e 100644 --- a/src/layer.rs +++ b/src/layer.rs @@ -5,9 +5,9 @@ use std::marker::PhantomData; use axum_core::response::IntoResponse; use bytes::Bytes; use futures_core::future::BoxFuture; -use http_body::{combinators::UnsyncBoxBody, Body}; +use http_body::Body; -use crate::{tx::TxSlot, Marker, State}; +use crate::{extension::Extension, Marker, State}; /// A [`tower_layer::Layer`] that enables the [`Tx`] extractor. /// @@ -97,7 +97,7 @@ where ResBody: Body + Send + 'static, ResBody::Error: Into>, { - type Response = http::Response>; + type Response = http::Response; type Error = S::Error; type Future = BoxFuture<'static, Result>; @@ -109,7 +109,8 @@ where } fn call(&mut self, mut req: http::Request) -> Self::Future { - let transaction = TxSlot::bind(req.extensions_mut(), self.state.clone()); + let ext = Extension::new(self.state.clone()); + req.extensions_mut().insert(ext.clone()); let res = self.inner.call(req); @@ -117,12 +118,12 @@ where let res = res.await.unwrap(); // inner service is infallible if !res.status().is_server_error() && !res.status().is_client_error() { - if let Err(error) = transaction.commit().await { + if let Err(error) = ext.resolve().await { return Ok(error.into().into_response()); } } - Ok(res.map(|body| body.map_err(axum_core::Error::new).boxed_unsync())) + Ok(res.map(axum_core::body::Body::new)) }) } } @@ -145,6 +146,6 @@ mod tests { .route("/", axum::routing::get(|| async { "hello" })) .layer(layer); - axum::Server::bind(todo!()).serve(app.into_make_service()); + axum::serve(todo!(), app); } } diff --git a/src/lib.rs b/src/lib.rs index 49607b9..9621884 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,7 +35,7 @@ //! # .route("/", axum::routing::get(|tx: Tx| async move {})) //! .layer(layer) //! .with_state(state); -//! # axum::Server::bind(todo!()).serve(app.into_make_service()); +//! # axum::serve(todo!(), app); //! # } //! ``` //! @@ -85,9 +85,9 @@ mod config; mod error; +mod extension; mod layer; mod marker; -mod slot; mod state; mod tx; diff --git a/src/marker.rs b/src/marker.rs index 11eccbd..55e7c6f 100644 --- a/src/marker.rs +++ b/src/marker.rs @@ -63,7 +63,7 @@ use std::fmt::Debug; /// .layer(layer1) /// .layer(layer2) /// .with_state(MyState { state1, state2 }); -/// # axum::Server::bind(todo!()).serve(app.into_make_service()); +/// # axum::serve(todo!(), app); /// # } /// ``` pub trait Marker: Debug + Send + Sized + 'static { diff --git a/src/slot.rs b/src/slot.rs deleted file mode 100644 index 2102ef9..0000000 --- a/src/slot.rs +++ /dev/null @@ -1,217 +0,0 @@ -//! An API for transferring ownership of a resource. -//! -//! The `Slot` and `Lease` types implement an API for sharing access to a resource `T`. -//! Conceptually, the `Slot` is the "primary" owner of the value, and access can be leased to one -//! other owner through an associated `Lease`. -//! -//! It's implemented as a wrapper around an `Arc>>`, where the `Slot` `take`s the -//! value from the `Option` on lease, and the `Lease` puts it back in on drop. -//! -//! Note that while this is **safe** to use across threads (it is `Send` + `Sync`), concurrently -//! `lease`ing and dropping `Lease`s is not supported. It's intended for use in synchronous -//! scenarios that the compiler thinks may be concurrent, like middleware stacks! - -use parking_lot::Mutex; -use std::sync::{Arc, Weak}; - -/// A slot that may contain a value that can be leased. -/// -/// The slot itself is opaque, the only way to see the value (if any) is with `lease` or -/// `into_inner`. -pub(crate) struct Slot(Arc>>); - -impl Slot { - /// Construct a new `Slot` holding the given value. - pub(crate) fn new(value: T) -> Self { - Self(Arc::new(Mutex::new(Some(value)))) - } - - /// Construct a new `Slot` and an immediately acquired `Lease`. - /// - /// This avoids the impossible `None` vs. calling `new()` then `lease()`. - pub(crate) fn new_leased(value: T) -> (Self, Lease) { - let mut slot = Self::new(value); - let lease = slot.lease().expect("BUG: new slot empty"); - (slot, lease) - } - - /// Lease the value from the slot, leaving it empty. - /// - /// Ownership of the contained value moves to the `Lease` for the duration. The value may return - /// to the slot when the `Lease` is dropped, or the value may be "stolen", leaving the slot - /// permanently empty. - pub(crate) fn lease(&mut self) -> Option> { - if let Some(value) = self.0.try_lock().and_then(|mut slot| slot.take()) { - Some(Lease::new(value, Arc::downgrade(&self.0))) - } else { - None - } - } - - /// Get the inner value from the slot, if any. - /// - /// Note that if this returns `Some`, there are no oustanding leases. If it returns `None` then - /// the value has been leased, and since this consumes the slot the value will be dropped once - /// the lease is done. - pub(crate) fn into_inner(self) -> Option { - self.0.try_lock().and_then(|mut slot| slot.take()) - } -} - -/// A lease of a value from a `Slot`. -#[derive(Debug)] -pub(crate) struct Lease(lease::State); - -impl Lease { - fn new(value: T, slot: Weak>>) -> Self { - Self(lease::State::new(value, slot)) - } - - /// Steal the value, meaning it will never return to the slot. - pub(crate) fn steal(mut self) -> T { - self.0.steal() - } -} - -impl Drop for Lease { - fn drop(&mut self) { - self.0.drop() - } -} - -impl AsRef for Lease { - fn as_ref(&self) -> &T { - self.0.as_ref() - } -} - -impl AsMut for Lease { - fn as_mut(&mut self) -> &mut T { - self.0.as_mut() - } -} - -impl std::ops::Deref for Lease { - type Target = T; - - fn deref(&self) -> &Self::Target { - self.0.as_ref() - } -} - -impl std::ops::DerefMut for Lease { - fn deref_mut(&mut self) -> &mut Self::Target { - self.0.as_mut() - } -} - -mod lease { - use std::sync::Weak; - - use parking_lot::Mutex; - - #[derive(Debug)] - pub(super) struct State(Inner); - - #[derive(Debug)] - enum Inner { - Dropped, - Stolen, - Live { - value: T, - slot: Weak>>, - }, - } - - impl State { - pub(super) fn new(value: T, slot: Weak>>) -> Self { - Self(Inner::Live { value, slot }) - } - - pub(super) fn as_ref(&self) -> &T { - match &self.0 { - Inner::Dropped | Inner::Stolen => panic!("BUG: LeaseState used after drop/steal"), - Inner::Live { value, .. } => value, - } - } - - pub(super) fn as_mut(&mut self) -> &mut T { - match &mut self.0 { - Inner::Dropped | Inner::Stolen => panic!("BUG: LeaseState used after drop/steal"), - Inner::Live { value, .. } => value, - } - } - - pub(super) fn drop(&mut self) { - match std::mem::replace(&mut self.0, Inner::Dropped) { - Inner::Dropped => panic!("BUG: LeaseState::drop called twice"), - Inner::Stolen => {} // nothing to do if the value was stolen - Inner::Live { value, slot } => { - // try to return value to the slot, if it fails just drop value - if let Some(slot) = slot.upgrade() { - if let Some(mut slot) = slot.try_lock() { - assert!(slot.is_none(), "BUG: slot repopulated during lease"); - *slot = Some(value); - } - } - } - } - } - - pub(super) fn steal(&mut self) -> T { - match std::mem::replace(&mut self.0, Inner::Stolen) { - Inner::Dropped => panic!("BUG: LeaseState::steal called after drop"), - Inner::Stolen => panic!("BUG: LeaseState::steal called twice"), - Inner::Live { value, .. } => value, - } - } - } -} - -#[cfg(test)] -mod tests { - use super::Slot; - - #[test] - fn lease_and_return() { - // Create a slot containing a resource. - let mut slot = Slot::new("Hello".to_string()); - - // Lease the resource, taking it from the slot. - let mut lease = slot.lease().unwrap(); - - std::thread::spawn(move || { - // We have exclusive access to the resource through the lease, which implements `Deref[Mut]` - lease.push_str(", world!"); - - // By default the value is returned to the slot on drop - }) - .join() - .unwrap(); - - // The value is now back in the slot - assert_eq!( - slot.lease().as_deref().map(|s| s.as_str()), - Some("Hello, world!") - ); - - // We can also take ownership of the value in the slot (if any) - assert_eq!(slot.into_inner(), Some("Hello, world!".to_string())); - } - - #[test] - fn lease_and_steal() { - let mut slot = Slot::new("Hello".to_string()); - - let lease = slot.lease().unwrap(); - std::thread::spawn(move || { - // We can steal ownership of the resource, leaving the slot permanently empty - let _: String = lease.steal(); - }) - .join() - .unwrap(); - - // The slot is now permanently empty - assert!(slot.lease().is_none()); - } -} diff --git a/src/tx.rs b/src/tx.rs index a013a90..e4a1cfa 100644 --- a/src/tx.rs +++ b/src/tx.rs @@ -1,6 +1,6 @@ //! A request extension that enables the [`Tx`](crate::Tx) extractor. -use std::marker::PhantomData; +use std::{fmt, marker::PhantomData}; use axum_core::{ extract::{FromRef, FromRequestParts}, @@ -8,10 +8,10 @@ use axum_core::{ }; use futures_core::{future::BoxFuture, stream::BoxStream}; use http::request::Parts; -use sqlx::Transaction; +use parking_lot::{lock_api::ArcMutexGuard, RawMutex}; use crate::{ - slot::{Lease, Slot}, + extension::{Extension, LazyTransaction}, Config, Error, Marker, State, }; @@ -73,9 +73,8 @@ use crate::{ /// /* ... */ /// } /// ``` -#[derive(Debug)] pub struct Tx { - tx: Lease>, + tx: ArcMutexGuard>, _error: PhantomData, } @@ -122,20 +121,26 @@ impl Tx { /// /// **Note:** trying to use the `Tx` extractor again after calling `commit` will currently /// generate [`Error::OverlappingExtractors`] errors. This may change in future. - pub async fn commit(self) -> Result<(), sqlx::Error> { - self.tx.steal().commit().await + pub async fn commit(mut self) -> Result<(), sqlx::Error> { + self.tx.commit().await + } +} + +impl fmt::Debug for Tx { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Tx").finish_non_exhaustive() } } impl AsRef> for Tx { fn as_ref(&self) -> &sqlx::Transaction<'static, DB::Driver> { - &self.tx + self.tx.as_ref() } } impl AsMut> for Tx { fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB::Driver> { - &mut self.tx + self.tx.as_mut() } } @@ -143,13 +148,13 @@ impl std::ops::Deref for Tx { type Target = sqlx::Transaction<'static, DB::Driver>; fn deref(&self) -> &Self::Target { - &self.tx + self.tx.as_ref() } } impl std::ops::DerefMut for Tx { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.tx + self.tx.as_mut() } } @@ -170,9 +175,9 @@ where 'state: 'ctx, { Box::pin(async move { - let ext: &mut Lazy = parts.extensions.get_mut().ok_or(Error::MissingExtension)?; + let ext: &Extension = parts.extensions.get().ok_or(Error::MissingExtension)?; - let tx = ext.get_or_begin().await?; + let tx = ext.acquire().await?; Ok(Self { tx, @@ -182,50 +187,6 @@ where } } -/// The OG `Slot` – the transaction (if any) returns here when the `Extension` is dropped. -pub(crate) struct TxSlot(Slot>>>); - -impl TxSlot { - /// Create a `TxSlot` bound to the given request extensions. - /// - /// When the request extensions are dropped, `commit` can be called to commit the transaction - /// (if any). - pub(crate) fn bind(extensions: &mut http::Extensions, state: State) -> Self { - let (slot, tx) = Slot::new_leased(None); - extensions.insert(Lazy { state, tx }); - Self(slot) - } - - pub(crate) async fn commit(self) -> Result<(), sqlx::Error> { - if let Some(tx) = self.0.into_inner().flatten().and_then(Slot::into_inner) { - tx.commit().await?; - } - Ok(()) - } -} - -/// A lazily acquired transaction. -/// -/// When the transaction is started, it's inserted into the `Option` leased from the `TxSlot`, so -/// that when `Lazy` is dropped the transaction is moved to the `TxSlot`. -struct Lazy { - state: State, - tx: Lease>>>, -} - -impl Lazy { - async fn get_or_begin(&mut self) -> Result>, Error> { - let tx = if let Some(tx) = self.tx.as_mut() { - tx - } else { - let tx = self.state.transaction().await?; - self.tx.insert(Slot::new(tx)) - }; - - tx.lease().ok_or(Error::OverlappingExtractors) - } -} - impl<'c, DB, E> sqlx::Executor<'c> for &'c mut Tx where DB: Marker, diff --git a/tests/lib.rs b/tests/lib.rs index 7351dfb..868b76d 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -1,5 +1,6 @@ use axum::{middleware, response::IntoResponse}; use axum_sqlx_tx::State; +use http_body_util::BodyExt; use sqlx::{sqlite::SqliteArguments, Arguments as _}; use tower::ServiceExt; @@ -81,10 +82,10 @@ async fn extract_from_middleware_and_handler() { .await .unwrap(); - async fn test_middleware( + async fn test_middleware( mut tx: Tx, - req: http::Request, - next: middleware::Next, + req: http::Request, + next: middleware::Next, ) -> impl IntoResponse { insert_user(&mut tx, 1, "bobby tables").await; @@ -123,12 +124,53 @@ async fn extract_from_middleware_and_handler() { .await .unwrap(); let status = response.status(); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body = response.into_body().collect().await.unwrap().to_bytes(); assert!(status.is_success()); assert_eq!(body.as_ref(), b"[[1,\"bobby tables\"]]"); } +#[tokio::test] +async fn middleware_cloning_request_extensions() { + let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); + + async fn test_middleware( + req: http::Request, + next: middleware::Next, + ) -> impl IntoResponse { + // Hold a clone of the request extensions + let _extensions = req.extensions().clone(); + + next.run(req).await + } + + let (state, layer) = Tx::setup(pool); + + let app = axum::Router::new() + .route("/", axum::routing::get(|_tx: Tx| async move {})) + .layer(middleware::from_fn_with_state( + state.clone(), + test_middleware, + )) + .layer(layer) + .with_state(state); + + let response = app + .oneshot( + http::Request::builder() + .uri("/") + .body(axum::body::Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + let status = response.status(); + let body = response.into_body().collect().await.unwrap().to_bytes(); + dbg!(body); + + assert!(status.is_success()); +} + #[tokio::test] async fn substates() { #[derive(Clone)] @@ -185,7 +227,7 @@ async fn missing_layer() { assert!(response.status().is_server_error()); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body = response.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body, format!("{}", axum_sqlx_tx::Error::MissingExtension)); } @@ -254,7 +296,7 @@ async fn layer_error_override() { .await .unwrap(); let status = response.status(); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body = response.into_body().collect().await.unwrap().to_bytes(); assert!(status.is_client_error()); assert_eq!(body, "internal server error"); @@ -374,7 +416,7 @@ struct Response { async fn build_app(handler: H) -> (sqlx::SqlitePool, Response) where - H: axum::handler::Handler, axum::body::Body>, + H: axum::handler::Handler>, T: 'static, { let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); @@ -401,7 +443,7 @@ where .await .unwrap(); let status = response.status(); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body = response.into_body().collect().await.unwrap().to_bytes(); (pool, Response { status, body }) }