Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove duplication in serving with and without graceful shutdown #2803

Merged
merged 5 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion axum/src/routing/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,7 @@ async fn logging_rejections() {
rejection_type: String,
}

let events = capture_tracing::<RejectionEvent, _, _>(|| async {
let events = capture_tracing::<RejectionEvent, _>(|| async {
let app = Router::new()
.route("/extension", get(|_: Extension<Infallible>| async {}))
.route("/string", post(|_: String| async {}));
Expand All @@ -987,6 +987,7 @@ async fn logging_rejections() {
StatusCode::BAD_REQUEST,
);
})
.with_filter("axum::rejection=trace")
.await;

assert_eq!(
Expand Down
57 changes: 2 additions & 55 deletions axum/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,61 +213,8 @@ where
type IntoFuture = private::ServeFuture;

fn into_future(self) -> Self::IntoFuture {
private::ServeFuture(Box::pin(async move {
let Self {
tcp_listener,
mut make_service,
tcp_nodelay,
_marker: _,
} = self;

loop {
let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
Some(conn) => conn,
None => continue,
};

if let Some(nodelay) = tcp_nodelay {
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
}
}

let tcp_stream = TokioIo::new(tcp_stream);

poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});

let tower_service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
remote_addr,
})
.await
.unwrap_or_else(|err| match err {})
.map_request(|req: Request<Incoming>| req.map(Body::new));

let hyper_service = TowerToHyperService::new(tower_service);

tokio::spawn(async move {
match Builder::new(TokioExecutor::new())
// upgrades needed for websockets
.serve_connection_with_upgrades(tcp_stream, hyper_service)
.await
{
Ok(()) => {}
Err(_err) => {
// This error only appears when the client doesn't send a request and
// terminate the connection.
//
// If client sends one request then terminate connection whenever, it doesn't
// appear.
}
}
});
}
}))
self.with_graceful_shutdown(std::future::pending())
.into_future()
}
}

Expand Down
96 changes: 68 additions & 28 deletions axum/src/test_helpers/tracing_helpers.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
use crate::util::AxumMutex;
use std::{future::Future, io, sync::Arc};
use std::{
future::{Future, IntoFuture},
io,
marker::PhantomData,
pin::Pin,
sync::Arc,
};

use serde::{de::DeserializeOwned, Deserialize};
use tracing::instrument::WithSubscriber;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{filter::Targets, fmt::MakeWriter};

Expand All @@ -14,36 +21,69 @@ pub(crate) struct TracingEvent<T> {
}

/// Run an async closure and capture the tracing output it produces.
pub(crate) async fn capture_tracing<T, F, Fut>(f: F) -> Vec<TracingEvent<T>>
pub(crate) fn capture_tracing<T, F>(f: F) -> CaptureTracing<T, F>
where
F: Fn() -> Fut,
Fut: Future,
T: DeserializeOwned,
{
let (make_writer, handle) = TestMakeWriter::new();

let subscriber = tracing_subscriber::registry().with(
tracing_subscriber::fmt::layer()
.with_writer(make_writer)
.with_target(true)
.without_time()
.with_ansi(false)
.json()
.flatten_event(false)
.with_filter("axum=trace".parse::<Targets>().unwrap()),
);

let guard = tracing::subscriber::set_default(subscriber);

f().await;

drop(guard);

handle
.take()
.lines()
.map(|line| serde_json::from_str(line).unwrap())
.collect()
CaptureTracing {
f,
filter: None,
_phantom: PhantomData,
}
}

pub(crate) struct CaptureTracing<T, F> {
f: F,
filter: Option<Targets>,
_phantom: PhantomData<fn() -> T>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why did you do this? I assume that since it's DeserializeOwned, the T will not be a reference so I don't see the point of going out of your way to make it contravariant. Am I missing something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fn(T) would be contravariant, fn() -> T is covariant. It's just good form for generic "output" parameters, doesn't actually matter here. See https://doc.rust-lang.org/nomicon/phantom-data.html#table-of-phantomdata-patterns for details.

}

impl<T, F> CaptureTracing<T, F> {
pub(crate) fn with_filter(mut self, filter_string: &str) -> Self {
self.filter = Some(filter_string.parse().unwrap());
self
}
}

impl<T, F, Fut> IntoFuture for CaptureTracing<T, F>
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future + Send,
T: DeserializeOwned,
{
type Output = Vec<TracingEvent<T>>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;

fn into_future(self) -> Self::IntoFuture {
let Self { f, filter, .. } = self;
Box::pin(async move {
let (make_writer, handle) = TestMakeWriter::new();

let filter = filter.unwrap_or_else(|| "axum=trace".parse().unwrap());
let subscriber = tracing_subscriber::registry().with(
tracing_subscriber::fmt::layer()
.with_writer(make_writer)
.with_target(true)
.without_time()
.with_ansi(false)
.json()
.flatten_event(false)
.with_filter(filter),
);

let guard = tracing::subscriber::set_default(subscriber);

f().with_current_subscriber().await;

drop(guard);

handle
.take()
.lines()
.map(|line| serde_json::from_str(line).unwrap())
.collect()
})
}
}

struct TestMakeWriter {
Expand Down
Loading