Skip to content

Commit

Permalink
Merge pull request #34 from digital-society-coop/axum-0.7
Browse files Browse the repository at this point in the history
Upgrade axum to 0.7
  • Loading branch information
connec committed Dec 23, 2023
2 parents 2b5a3c0 + 1cec0ae commit 9a2f761
Show file tree
Hide file tree
Showing 10 changed files with 203 additions and 302 deletions.
13 changes: 7 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,21 @@ runtime-tokio-rustls = ["sqlx/runtime-tokio-rustls"]
features = ["all-databases", "runtime-tokio-rustls"]

[dependencies]
axum-core = "0.3"
axum-core = "0.4"
bytes = "1"
futures-core = "0.3"
http = "0.2"
http-body = "0.4"
parking_lot = "0.12"
http = "1"
http-body = "1"
parking_lot = { version = "0.12", features = ["arc_lock", "send_guard"] }
sqlx = { version = "0.7", default-features = false }
thiserror = "1"
tower-layer = "0.3"
tower-service = "0.3"

[dev-dependencies]
axum-sqlx-tx = { path = ".", features = ["runtime-tokio-rustls", "sqlite"] }
axum = "0.6.4"
hyper = "0.14.17"
axum = "0.7.1"
http-body-util = "0.1.0"
hyper = "1.0.1"
tokio = { version = "1.17.0", features = ["macros", "rt-multi-thread"] }
tower = "0.4.12"
6 changes: 3 additions & 3 deletions examples/example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Add the Tx state
.with_state(state);

let server = axum::Server::bind(&([0, 0, 0, 0], 0).into()).serve(app.into_make_service());
let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
println!("Listening on {}", listener.local_addr().unwrap());

println!("Listening on {}", server.local_addr());
server.await?;
axum::serve(listener, app).await?;

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
/// # .route("/", post(create_user))
/// # .layer(layer)
/// # .with_state(state);
/// # axum::Server::bind(todo!()).serve(app.into_make_service());
/// # axum::serve(todo!(), app);
/// # }
/// # async fn create_user(mut tx: Tx, /* ... */) {
/// # /* ... */
Expand Down
113 changes: 113 additions & 0 deletions src/extension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use std::sync::Arc;

use parking_lot::{lock_api::ArcMutexGuard, Mutex, RawMutex};
use sqlx::Transaction;

use crate::{Error, Marker, State};

/// The request extension.
pub(crate) struct Extension<DB: Marker> {
slot: Arc<Mutex<LazyTransaction<DB>>>,
}

impl<DB: Marker> Extension<DB> {
pub(crate) fn new(state: State<DB>) -> Self {
let slot = Arc::new(Mutex::new(LazyTransaction::new(state)));
Self { slot }
}

pub(crate) async fn acquire(
&self,
) -> Result<ArcMutexGuard<RawMutex, LazyTransaction<DB>>, Error> {
let mut tx = self
.slot
.try_lock_arc()
.ok_or(Error::OverlappingExtractors)?;
tx.acquire().await?;

Ok(tx)
}

pub(crate) async fn resolve(&self) -> Result<(), sqlx::Error> {
if let Some(mut tx) = self.slot.try_lock_arc() {
tx.resolve().await?;
}
Ok(())
}
}

impl<DB: Marker> Clone for Extension<DB> {
fn clone(&self) -> Self {
Self {
slot: self.slot.clone(),
}
}
}

/// The lazy transaction.
pub(crate) struct LazyTransaction<DB: Marker>(LazyTransactionState<DB>);

enum LazyTransactionState<DB: Marker> {
Unacquired {
state: State<DB>,
},
Acquired {
tx: Transaction<'static, DB::Driver>,
},
Resolved,
}

impl<DB: Marker> LazyTransaction<DB> {
fn new(state: State<DB>) -> Self {
Self(LazyTransactionState::Unacquired { state })
}

pub(crate) fn as_ref(&self) -> &Transaction<'static, DB::Driver> {
match &self.0 {
LazyTransactionState::Unacquired { .. } => {
panic!("BUG: exposed unacquired LazyTransaction")
}
LazyTransactionState::Acquired { tx } => tx,
LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"),
}
}

pub(crate) fn as_mut(&mut self) -> &mut Transaction<'static, DB::Driver> {
match &mut self.0 {
LazyTransactionState::Unacquired { .. } => {
panic!("BUG: exposed unacquired LazyTransaction")
}
LazyTransactionState::Acquired { tx } => tx,
LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"),
}
}

async fn acquire(&mut self) -> Result<(), Error> {
match &self.0 {
LazyTransactionState::Unacquired { state } => {
let tx = state.transaction().await?;
self.0 = LazyTransactionState::Acquired { tx };
Ok(())
}
LazyTransactionState::Acquired { .. } => Ok(()),
LazyTransactionState::Resolved => Err(Error::OverlappingExtractors),
}
}

pub(crate) async fn resolve(&mut self) -> Result<(), sqlx::Error> {
match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) {
LazyTransactionState::Unacquired { .. } | LazyTransactionState::Resolved => Ok(()),
LazyTransactionState::Acquired { tx } => tx.commit().await,
}
}

pub(crate) async fn commit(&mut self) -> Result<(), sqlx::Error> {
match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) {
LazyTransactionState::Unacquired { .. } => {
panic!("BUG: tried to commit unacquired transaction")
}
LazyTransactionState::Acquired { tx } => tx.commit().await,
LazyTransactionState::Resolved => panic!("BUG: tried to commit resolved transaction"),
}
}
}
15 changes: 8 additions & 7 deletions src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use std::marker::PhantomData;
use axum_core::response::IntoResponse;
use bytes::Bytes;
use futures_core::future::BoxFuture;
use http_body::{combinators::UnsyncBoxBody, Body};
use http_body::Body;

use crate::{tx::TxSlot, Marker, State};
use crate::{extension::Extension, Marker, State};

/// A [`tower_layer::Layer`] that enables the [`Tx`] extractor.
///
Expand Down Expand Up @@ -97,7 +97,7 @@ where
ResBody: Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
type Response = http::Response<UnsyncBoxBody<ResBody::Data, axum_core::Error>>;
type Response = http::Response<axum_core::body::Body>;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

Expand All @@ -109,20 +109,21 @@ where
}

fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
let transaction = TxSlot::bind(req.extensions_mut(), self.state.clone());
let ext = Extension::new(self.state.clone());
req.extensions_mut().insert(ext.clone());

let res = self.inner.call(req);

Box::pin(async move {
let res = res.await.unwrap(); // inner service is infallible

if !res.status().is_server_error() && !res.status().is_client_error() {
if let Err(error) = transaction.commit().await {
if let Err(error) = ext.resolve().await {
return Ok(error.into().into_response());
}
}

Ok(res.map(|body| body.map_err(axum_core::Error::new).boxed_unsync()))
Ok(res.map(axum_core::body::Body::new))
})
}
}
Expand All @@ -145,6 +146,6 @@ mod tests {
.route("/", axum::routing::get(|| async { "hello" }))
.layer(layer);

axum::Server::bind(todo!()).serve(app.into_make_service());
axum::serve(todo!(), app);
}
}
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
//! # .route("/", axum::routing::get(|tx: Tx| async move {}))
//! .layer(layer)
//! .with_state(state);
//! # axum::Server::bind(todo!()).serve(app.into_make_service());
//! # axum::serve(todo!(), app);
//! # }
//! ```
//!
Expand Down Expand Up @@ -85,9 +85,9 @@

mod config;
mod error;
mod extension;
mod layer;
mod marker;
mod slot;
mod state;
mod tx;

Expand Down
2 changes: 1 addition & 1 deletion src/marker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ use std::fmt::Debug;
/// .layer(layer1)
/// .layer(layer2)
/// .with_state(MyState { state1, state2 });
/// # axum::Server::bind(todo!()).serve(app.into_make_service());
/// # axum::serve(todo!(), app);
/// # }
/// ```
pub trait Marker: Debug + Send + Sized + 'static {
Expand Down
Loading

0 comments on commit 9a2f761

Please sign in to comment.