Skip to content

Commit

Permalink
fix: fix buggy behaviour when cloning request extensions
Browse files Browse the repository at this point in the history
Our naive update for axum@0.7 / hyper@1.0 led to buggy behaviour whereby
the `Tx` extractor would always fail with `OverlappingExtractors` if
there were any outstanding clones of the request extensions (see the new
test). Since request extensions now must all implement `Clone`, it's
possible that some middleware might wish to keep a clone of all request
extensions (e.g. for request inspection/tracing/debugging), rendering
`Tx` unusable with those middleware.

To fix it, we simplify the synchronisation by implementing `Clone` for
`Slot` and creating new `Extension<DB>` and `LazyTransaction<DB>` types
to replace `TxSlot<DB>` and `Lazy<DB>`.

`Slot<T>` is a wrapper around an `Arc<Mutex<Option<T>>>`, and as such it
can trivially implement `Clone` (there was some "pit of success"
considerations with the previous API intended to enforce proper usage,
but that is unnecesarily limiting given the underlying `Mutex`).

The `Extension<DB>` holds a `Slot` containing a `LazyTransaction<DB>`.
`Extension<DB>` is trivially clonable since `Slot` itself is. The
`LazyTransaction<DB>` then implements a simple "lazily acquired
transaction" protocol, making use of normal rust ownership and borrowing
rules to manage the transaction (i.e. it has no internal
synchronisation).

This makes the overall synchronisation picture much simpler: the
middleware future and all clones of the request extension hold a
reference to the same `Slot`. The `Tx` extractor obtains its copy of the
request extension and attempts to `lease` the inner `LazyTransaction`,
failing with `OverlappingExtractors` if the lease is already taken (this
is the only public invocation of `lease`, and so overlapping extractors
can be the only* cause of an absent transaction). If the lease is
successful, the extractor can acquire a transaction (if there's not one
already) and package it up for request handlers to then interact with.

* Technically the transaction can be "stolen" from the `Tx` extractor by
committing explicitly, but this considered to create an endless
"overlap" in the current semantics.

The main caveat of this approach seems to be that the `Tx` extractor no
longer has type-level knowledge that it can access a `Transaction` - the
`Transaction` can only be accessed by matching on the `LazyTransaction`
state. This doesn't affect the API, but it could make bumping into a
panic more likely, or there may be performance implications (though
these would likely be dwarfed by the I/O involved in interacting with a
database).
  • Loading branch information
connec committed Dec 23, 2023
1 parent 9e9827f commit 84a49ae
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 83 deletions.
103 changes: 103 additions & 0 deletions src/extension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use sqlx::Transaction;

use crate::{
slot::{Lease, Slot},
Error, Marker, State,
};

/// The request extension.
pub(crate) struct Extension<DB: Marker> {
slot: Slot<LazyTransaction<DB>>,
}

impl<DB: Marker> Extension<DB> {
pub(crate) fn new(state: State<DB>) -> Self {
let slot = Slot::new(LazyTransaction::new(state));
Self { slot }
}

pub(crate) async fn acquire(&self) -> Result<Lease<LazyTransaction<DB>>, Error> {
let mut tx = self.slot.lease().ok_or(Error::OverlappingExtractors)?;
tx.acquire().await?;

Ok(tx)
}

pub(crate) async fn resolve(&self) -> Result<(), sqlx::Error> {
if let Some(tx) = self.slot.lease() {
tx.steal().resolve().await?;
}
Ok(())
}
}

impl<DB: Marker> Clone for Extension<DB> {
fn clone(&self) -> Self {
Self {
slot: self.slot.clone(),
}
}
}

/// The lazy transaction.
pub(crate) struct LazyTransaction<DB: Marker>(LazyTransactionState<DB>);

enum LazyTransactionState<DB: Marker> {
Unacquired {
state: State<DB>,
},
Acquired {
tx: Transaction<'static, DB::Driver>,
},
}

impl<DB: Marker> LazyTransaction<DB> {
fn new(state: State<DB>) -> Self {
Self(LazyTransactionState::Unacquired { state })
}

pub(crate) fn as_ref(&self) -> &Transaction<'static, DB::Driver> {
match &self.0 {
LazyTransactionState::Unacquired { .. } => {
panic!("BUG: exposed unacquired LazyTransaction")
}
LazyTransactionState::Acquired { tx } => tx,
}
}

pub(crate) fn as_mut(&mut self) -> &mut Transaction<'static, DB::Driver> {
match &mut self.0 {
LazyTransactionState::Unacquired { .. } => {
panic!("BUG: exposed unacquired LazyTransaction")
}
LazyTransactionState::Acquired { tx } => tx,
}
}

async fn acquire(&mut self) -> Result<(), sqlx::Error> {
match &self.0 {
LazyTransactionState::Unacquired { state } => {
let tx = state.transaction().await?;
self.0 = LazyTransactionState::Acquired { tx };
Ok(())
}
LazyTransactionState::Acquired { .. } => Ok(()),
}
}

pub(crate) async fn resolve(self) -> Result<(), sqlx::Error> {
match self.0 {
LazyTransactionState::Unacquired { .. } => Ok(()),
LazyTransactionState::Acquired { tx } => tx.commit().await,
}
}

pub(crate) async fn commit(self) -> Result<(), sqlx::Error> {
match self.0 {
LazyTransactionState::Unacquired { .. } => {
panic!("BUG: tried to commit unacquired transaction")
}
LazyTransactionState::Acquired { tx } => tx.commit().await,
}
}
}
7 changes: 4 additions & 3 deletions src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use bytes::Bytes;
use futures_core::future::BoxFuture;
use http_body::Body;

use crate::{tx::TxSlot, Marker, State};
use crate::{extension::Extension, Marker, State};

/// A [`tower_layer::Layer`] that enables the [`Tx`] extractor.
///
Expand Down Expand Up @@ -109,15 +109,16 @@ where
}

fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
let transaction = TxSlot::bind(req.extensions_mut(), self.state.clone());
let ext = Extension::new(self.state.clone());
req.extensions_mut().insert(ext.clone());

let res = self.inner.call(req);

Box::pin(async move {
let res = res.await.unwrap(); // inner service is infallible

if !res.status().is_server_error() && !res.status().is_client_error() {
if let Err(error) = transaction.commit().await {
if let Err(error) = ext.resolve().await {
return Ok(error.into().into_response());
}
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@

mod config;
mod error;
mod extension;
mod layer;
mod marker;
mod slot;
Expand Down
31 changes: 11 additions & 20 deletions src/slot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,23 @@ impl<T> Slot<T> {
Self(Arc::new(Mutex::new(Some(value))))
}

/// Construct a new `Slot` and an immediately acquired `Lease`.
///
/// This avoids the impossible `None` vs. calling `new()` then `lease()`.
pub(crate) fn new_leased(value: T) -> (Self, Lease<T>) {
let mut slot = Self::new(value);
let lease = slot.lease().expect("BUG: new slot empty");
(slot, lease)
}

/// Lease the value from the slot, leaving it empty.
///
/// Ownership of the contained value moves to the `Lease` for the duration. The value may return
/// to the slot when the `Lease` is dropped, or the value may be "stolen", leaving the slot
/// permanently empty.
pub(crate) fn lease(&mut self) -> Option<Lease<T>> {
pub(crate) fn lease(&self) -> Option<Lease<T>> {
if let Some(value) = self.0.try_lock().and_then(|mut slot| slot.take()) {
Some(Lease::new(value, Arc::downgrade(&self.0)))
} else {
None
}
}
}

/// Get the inner value from the slot, if any.
///
/// Note that if this returns `Some`, there are no oustanding leases. If it returns `None` then
/// the value has been leased, and since this consumes the slot the value will be dropped once
/// the lease is done.
pub(crate) fn into_inner(self) -> Option<T> {
self.0.try_lock().and_then(|mut slot| slot.take())
impl<T> Clone for Slot<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

Expand Down Expand Up @@ -175,7 +163,7 @@ mod tests {
#[test]
fn lease_and_return() {
// Create a slot containing a resource.
let mut slot = Slot::new("Hello".to_string());
let slot = Slot::new("Hello".to_string());

// Lease the resource, taking it from the slot.
let mut lease = slot.lease().unwrap();
Expand All @@ -196,12 +184,15 @@ mod tests {
);

// We can also take ownership of the value in the slot (if any)
assert_eq!(slot.into_inner(), Some("Hello, world!".to_string()));
assert_eq!(
slot.lease().map(|lease| lease.steal()),
Some("Hello, world!".to_string())
);
}

#[test]
fn lease_and_steal() {
let mut slot = Slot::new("Hello".to_string());
let slot = Slot::new("Hello".to_string());

let lease = slot.lease().unwrap();
std::thread::spawn(move || {
Expand Down
76 changes: 16 additions & 60 deletions src/tx.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
//! A request extension that enables the [`Tx`](crate::Tx) extractor.

use std::{marker::PhantomData, sync::Arc};
use std::{fmt, marker::PhantomData};

use axum_core::{
extract::{FromRef, FromRequestParts},
response::IntoResponse,
};
use futures_core::{future::BoxFuture, stream::BoxStream};
use http::request::Parts;
use sqlx::Transaction;

use crate::{
slot::{Lease, Slot},
extension::{Extension, LazyTransaction},
slot::Lease,
Config, Error, Marker, State,
};

Expand Down Expand Up @@ -73,9 +73,8 @@ use crate::{
/// /* ... */
/// }
/// ```
#[derive(Debug)]
pub struct Tx<DB: Marker, E = Error> {
tx: Lease<sqlx::Transaction<'static, DB::Driver>>,
tx: Lease<LazyTransaction<DB>>,
_error: PhantomData<E>,
}

Expand Down Expand Up @@ -127,29 +126,35 @@ impl<DB: Marker, E> Tx<DB, E> {
}
}

impl<DB: Marker, E> fmt::Debug for Tx<DB, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Tx").finish_non_exhaustive()
}
}

impl<DB: Marker, E> AsRef<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E> {
fn as_ref(&self) -> &sqlx::Transaction<'static, DB::Driver> {
&self.tx
self.tx.as_ref().as_ref()
}
}

impl<DB: Marker, E> AsMut<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E> {
fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB::Driver> {
&mut self.tx
self.tx.as_mut().as_mut()
}
}

impl<DB: Marker, E> std::ops::Deref for Tx<DB, E> {
type Target = sqlx::Transaction<'static, DB::Driver>;

fn deref(&self) -> &Self::Target {
&self.tx
self.tx.as_ref().as_ref()
}
}

impl<DB: Marker, E> std::ops::DerefMut for Tx<DB, E> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.tx
self.tx.as_mut().as_mut()
}
}

Expand All @@ -170,13 +175,9 @@ where
'state: 'ctx,
{
Box::pin(async move {
let ext: &mut Arc<Lazy<DB>> =
parts.extensions.get_mut().ok_or(Error::MissingExtension)?;
let ext: &Extension<DB> = parts.extensions.get().ok_or(Error::MissingExtension)?;

let tx = Arc::get_mut(ext)
.ok_or(Error::OverlappingExtractors)?
.get_or_begin()
.await?;
let tx = ext.acquire().await?;

Ok(Self {
tx,
Expand All @@ -186,51 +187,6 @@ where
}
}

/// The OG `Slot` – the transaction (if any) returns here when the `Extension` is dropped.
pub(crate) struct TxSlot<DB: Marker>(Slot<Option<Slot<Transaction<'static, DB::Driver>>>>);

impl<DB: Marker> TxSlot<DB> {
/// Create a `TxSlot` bound to the given request extensions.
///
/// When the request extensions are dropped, `commit` can be called to commit the transaction
/// (if any).
pub(crate) fn bind(extensions: &mut http::Extensions, state: State<DB>) -> Self {
let (slot, tx) = Slot::new_leased(None);
let lazy = Arc::new(Lazy { state, tx });
extensions.insert(lazy);
Self(slot)
}

pub(crate) async fn commit(self) -> Result<(), sqlx::Error> {
if let Some(tx) = self.0.into_inner().flatten().and_then(Slot::into_inner) {
tx.commit().await?;
}
Ok(())
}
}

/// A lazily acquired transaction.
///
/// When the transaction is started, it's inserted into the `Option` leased from the `TxSlot`, so
/// that when `Lazy` is dropped the transaction is moved to the `TxSlot`.
struct Lazy<DB: Marker> {
state: State<DB>,
tx: Lease<Option<Slot<Transaction<'static, DB::Driver>>>>,
}

impl<DB: Marker> Lazy<DB> {
async fn get_or_begin(&mut self) -> Result<Lease<Transaction<'static, DB::Driver>>, Error> {
let tx = if let Some(tx) = self.tx.as_mut() {
tx
} else {
let tx = self.state.transaction().await?;
self.tx.insert(Slot::new(tx))
};

tx.lease().ok_or(Error::OverlappingExtractors)
}
}

impl<'c, DB, E> sqlx::Executor<'c> for &'c mut Tx<DB, E>
where
DB: Marker,
Expand Down
41 changes: 41 additions & 0 deletions tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,47 @@ async fn extract_from_middleware_and_handler() {
assert_eq!(body.as_ref(), b"[[1,\"bobby tables\"]]");
}

#[tokio::test]
async fn middleware_cloning_request_extensions() {
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();

async fn test_middleware(
req: http::Request<axum::body::Body>,
next: middleware::Next,
) -> impl IntoResponse {
// Hold a clone of the request extensions
let _extensions = req.extensions().clone();

next.run(req).await
}

let (state, layer) = Tx::setup(pool);

let app = axum::Router::new()
.route("/", axum::routing::get(|_tx: Tx| async move {}))
.layer(middleware::from_fn_with_state(
state.clone(),
test_middleware,
))
.layer(layer)
.with_state(state);

let response = app
.oneshot(
http::Request::builder()
.uri("/")
.body(axum::body::Body::empty())
.unwrap(),
)
.await
.unwrap();
let status = response.status();
let body = response.into_body().collect().await.unwrap().to_bytes();
dbg!(body);

assert!(status.is_success());
}

#[tokio::test]
async fn substates() {
#[derive(Clone)]
Expand Down

0 comments on commit 84a49ae

Please sign in to comment.