diff --git a/Cargo.toml b/Cargo.toml index a421a7f..f52cb1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ pin-project-lite = "0.2.4" futures-channel = { version = "0.3", optional = true } socket2 = { version = "0.5", optional = true, features = ["all"] } tracing = { version = "0.1", default-features = false, features = ["std"], optional = true } -tokio = { version = "1", optional = true, features = ["net", "rt", "time"] } +tokio = { version = "1", optional = true, default-features = false } tower-service ={ version = "0.3", optional = true } tower = { version = "0.4.1", optional = true, features = ["make", "util"] } @@ -57,7 +57,7 @@ full = [ ] client = ["hyper/client", "dep:tracing", "dep:futures-channel", "dep:tower", "dep:tower-service"] -client-legacy = ["client", "dep:socket2"] +client-legacy = ["client", "dep:socket2", "tokio/sync"] server = ["hyper/server"] server-auto = ["server", "http1", "http2"] @@ -67,7 +67,7 @@ service = ["dep:tower", "dep:tower-service"] http1 = ["hyper/http1"] http2 = ["hyper/http2"] -tokio = ["dep:tokio"] +tokio = ["dep:tokio", "tokio/net", "tokio/rt", "tokio/time"] # internal features used in CI __internal_happy_eyeballs_tests = [] diff --git a/src/client/legacy/client.rs b/src/client/legacy/client.rs index e50dca2..166c572 100644 --- a/src/client/legacy/client.rs +++ b/src/client/legacy/client.rs @@ -18,6 +18,7 @@ use hyper::rt::Timer; use hyper::{body::Body, Method, Request, Response, Uri, Version}; use tracing::{debug, trace, warn}; +use super::connect::capture::CaptureConnectionExtension; #[cfg(feature = "tokio")] use super::connect::HttpConnector; use super::connect::{Alpn, Connect, Connected, Connection}; @@ -265,6 +266,10 @@ where ) -> Result, Error> { let mut pooled = self.connection_for(pool_key).await?; + req.extensions_mut() + .get_mut::() + .map(|conn| conn.set(&pooled.conn_info)); + if pooled.is_http1() { if req.version() == Version::HTTP_2 { warn!("Connection is HTTP/1, but request requires HTTP/2"); diff --git a/src/client/legacy/connect/capture.rs b/src/client/legacy/connect/capture.rs new file mode 100644 index 0000000..4fbe384 --- /dev/null +++ b/src/client/legacy/connect/capture.rs @@ -0,0 +1,191 @@ +use std::{ops::Deref, sync::Arc}; + +use http::Request; +use tokio::sync::watch; + +use super::Connected; + +/// [`CaptureConnection`] allows callers to capture [`Connected`] information +/// +/// To capture a connection for a request, use [`capture_connection`]. +#[derive(Debug, Clone)] +pub struct CaptureConnection { + rx: watch::Receiver>, +} + +/// Capture the connection for a given request +/// +/// When making a request with Hyper, the underlying connection must implement the [`Connection`] trait. +/// [`capture_connection`] allows a caller to capture the returned [`Connected`] structure as soon +/// as the connection is established. +/// +/// *Note*: If establishing a connection fails, [`CaptureConnection::connection_metadata`] will always return none. +/// +/// # Examples +/// +/// **Synchronous access**: +/// The [`CaptureConnection::connection_metadata`] method allows callers to check if a connection has been +/// established. This is ideal for situations where you are certain the connection has already +/// been established (e.g. after the response future has already completed). +/// ```rust +/// use hyper_util::client::legacy::connect::capture_connection; +/// let mut request = http::Request::builder() +/// .uri("http://foo.com") +/// .body(()) +/// .unwrap(); +/// +/// let captured_connection = capture_connection(&mut request); +/// // some time later after the request has been sent... +/// let connection_info = captured_connection.connection_metadata(); +/// println!("we are connected! {:?}", connection_info.as_ref()); +/// ``` +/// +/// **Asynchronous access**: +/// The [`CaptureConnection::wait_for_connection_metadata`] method returns a future resolves as soon as the +/// connection is available. +/// +/// ```rust +/// # #[cfg(feature = "tokio")] +/// # async fn example() { +/// use hyper_util::client::legacy::connect::capture_connection; +/// use hyper_util::client::legacy::Client; +/// use hyper_util::rt::TokioExecutor; +/// use bytes::Bytes; +/// use http_body_util::Empty; +/// let mut request = http::Request::builder() +/// .uri("http://foo.com") +/// .body(Empty::::new()) +/// .unwrap(); +/// +/// let mut captured = capture_connection(&mut request); +/// tokio::task::spawn(async move { +/// let connection_info = captured.wait_for_connection_metadata().await; +/// println!("we are connected! {:?}", connection_info.as_ref()); +/// }); +/// +/// let client = Client::builder(TokioExecutor::new()).build_http(); +/// client.request(request).await.expect("request failed"); +/// # } +/// ``` +pub fn capture_connection(request: &mut Request) -> CaptureConnection { + let (tx, rx) = CaptureConnection::new(); + request.extensions_mut().insert(tx); + rx +} + +/// TxSide for [`CaptureConnection`] +/// +/// This is inserted into `Extensions` to allow Hyper to back channel connection info +#[derive(Clone)] +pub(crate) struct CaptureConnectionExtension { + tx: Arc>>, +} + +impl CaptureConnectionExtension { + pub(crate) fn set(&self, connected: &Connected) { + self.tx.send_replace(Some(connected.clone())); + } +} + +impl CaptureConnection { + /// Internal API to create the tx and rx half of [`CaptureConnection`] + pub(crate) fn new() -> (CaptureConnectionExtension, Self) { + let (tx, rx) = watch::channel(None); + ( + CaptureConnectionExtension { tx: Arc::new(tx) }, + CaptureConnection { rx }, + ) + } + + /// Retrieve the connection metadata, if available + pub fn connection_metadata(&self) -> impl Deref> + '_ { + self.rx.borrow() + } + + /// Wait for the connection to be established + /// + /// If a connection was established, this will always return `Some(...)`. If the request never + /// successfully connected (e.g. DNS resolution failure), this method will never return. + pub async fn wait_for_connection_metadata( + &mut self, + ) -> impl Deref> + '_ { + if self.rx.borrow().is_some() { + return self.rx.borrow(); + } + let _ = self.rx.changed().await; + self.rx.borrow() + } +} + +#[cfg(all(test, not(miri)))] +mod test { + use super::*; + + #[test] + fn test_sync_capture_connection() { + let (tx, rx) = CaptureConnection::new(); + assert!( + rx.connection_metadata().is_none(), + "connection has not been set" + ); + tx.set(&Connected::new().proxy(true)); + assert_eq!( + rx.connection_metadata() + .as_ref() + .expect("connected should be set") + .is_proxied(), + true + ); + + // ensure it can be called multiple times + assert_eq!( + rx.connection_metadata() + .as_ref() + .expect("connected should be set") + .is_proxied(), + true + ); + } + + #[tokio::test] + async fn async_capture_connection() { + let (tx, mut rx) = CaptureConnection::new(); + assert!( + rx.connection_metadata().is_none(), + "connection has not been set" + ); + let test_task = tokio::spawn(async move { + assert_eq!( + rx.wait_for_connection_metadata() + .await + .as_ref() + .expect("connection should be set") + .is_proxied(), + true + ); + // can be awaited multiple times + assert!( + rx.wait_for_connection_metadata().await.is_some(), + "should be awaitable multiple times" + ); + + assert_eq!(rx.connection_metadata().is_some(), true); + }); + // can't be finished, we haven't set the connection yet + assert_eq!(test_task.is_finished(), false); + tx.set(&Connected::new().proxy(true)); + + assert!(test_task.await.is_ok()); + } + + #[tokio::test] + async fn capture_connection_sender_side_dropped() { + let (tx, mut rx) = CaptureConnection::new(); + assert!( + rx.connection_metadata().is_none(), + "connection has not been set" + ); + drop(tx); + assert!(rx.wait_for_connection_metadata().await.is_none()); + } +} diff --git a/src/client/legacy/connect/mod.rs b/src/client/legacy/connect/mod.rs index 104ddd8..bd00baa 100644 --- a/src/client/legacy/connect/mod.rs +++ b/src/client/legacy/connect/mod.rs @@ -74,6 +74,9 @@ pub mod dns; #[cfg(feature = "tokio")] mod http; +pub(crate) mod capture; +pub use capture::{capture_connection, CaptureConnection}; + pub use self::sealed::Connect; /// Describes a type returned by a connector. @@ -169,7 +172,6 @@ impl Connected { // Don't public expose that `Connected` is `Clone`, unsure if we want to // keep that contract... - #[cfg(feature = "http2")] pub(super) fn clone(&self) -> Connected { Connected { alpn: self.alpn, diff --git a/tests/legacy_client.rs b/tests/legacy_client.rs index 3aab054..28babd7 100644 --- a/tests/legacy_client.rs +++ b/tests/legacy_client.rs @@ -18,7 +18,7 @@ use http_body_util::{Empty, Full, StreamBody}; use hyper::body::Bytes; use hyper::body::Frame; use hyper::Request; -use hyper_util::client::legacy::connect::HttpConnector; +use hyper_util::client::legacy::connect::{capture_connection, HttpConnector}; use hyper_util::client::legacy::Client; use hyper_util::rt::{TokioExecutor, TokioIo}; @@ -876,3 +876,35 @@ fn alpn_h2() { ); drop(client); } + +#[cfg(not(miri))] +#[test] +fn capture_connection_on_client() { + let _ = pretty_env_logger::try_init(); + + let rt = runtime(); + let connector = DebugConnector::new(); + + let client = Client::builder(TokioExecutor::new()).build(connector); + + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + thread::spawn(move || { + let mut sock = server.accept().unwrap().0; + //drop(server); + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))) + .unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + .expect("write 1"); + }); + let mut req = Request::builder() + .uri(&*format!("http://{}/a", addr)) + .body(Empty::::new()) + .unwrap(); + let captured_conn = capture_connection(&mut req); + rt.block_on(client.request(req)).expect("200 OK"); + assert!(captured_conn.connection_metadata().is_some()); +}