From 7d512f9145cdb1c815e2e2225b19fb45ac5f8fd6 Mon Sep 17 00:00:00 2001 From: Tomek Karwowski Date: Mon, 7 Aug 2023 20:33:58 +0200 Subject: [PATCH] WIP: feat: SetHost, Http1RequestTarget and DelayedResposne middlewares --- Cargo.toml | 10 ++- src/client/mod.rs | 1 + src/client/services/delayed_response.rs | 71 +++++++++++++++++ src/client/services/http1_request_target.rs | 85 +++++++++++++++++++++ src/client/services/mod.rs | 7 ++ src/client/services/set_host.rs | 50 ++++++++++++ src/common/mod.rs | 2 - 7 files changed, 221 insertions(+), 5 deletions(-) create mode 100644 src/client/services/delayed_response.rs create mode 100644 src/client/services/http1_request_target.rs create mode 100644 src/client/services/mod.rs create mode 100644 src/client/services/set_host.rs diff --git a/Cargo.toml b/Cargo.toml index b98b56f..71d4e40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,11 @@ repository = "https://github.com/hyperium/hyper-util" license = "MIT" authors = ["Sean McArthur "] keywords = ["http", "hyper", "hyperium"] -categories = ["network-programming", "web-programming::http-client", "web-programming::http-server"] +categories = [ + "network-programming", + "web-programming::http-client", + "web-programming::http-server", +] edition = "2018" publish = false # no accidents while in dev @@ -35,8 +39,8 @@ pnet_datalink = "0.27.2" [features] runtime = [] tcp = [] -http1 = [] -http2 = [] +http1 = ["hyper/http1"] +http2 = ["hyper/http2"] # internal features used in CI __internal_happy_eyeballs_tests = [] diff --git a/src/client/mod.rs b/src/client/mod.rs index 7b5210a..50355ae 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -4,3 +4,4 @@ pub mod client; pub mod connect; pub mod pool; +pub mod services; diff --git a/src/client/services/delayed_response.rs b/src/client/services/delayed_response.rs new file mode 100644 index 0000000..9770974 --- /dev/null +++ b/src/client/services/delayed_response.rs @@ -0,0 +1,71 @@ +use std::{ + future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_channel::oneshot; +use futures_util::Future; +use http::Response; +use hyper::Body; +use pin_project_lite::pin_project; +use tower_service::Service; + +pub struct DelayedResponse { + inner: S, +} + +impl Service for DelayedResponse +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = DelayedResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Req) -> Self::Future { + DelayedResponseFuture { + inner: self.inner.call(req), + } + } +} + +pin_project! { + struct DelayedResponseFuture { + #[pin] + inner: F, + } +} + +impl Future for DelayedResponseFuture +where + F: Future, E>>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().inner.poll(cx) { + Poll::Ready(res) => { + let res = res?; + let (delayed_tx, delayed_rx) = oneshot::channel(); + res.body_mut().delayed_eof(delayed_rx); + let on_idle = future::poll_fn(move |cx| pooled.poll_ready(cx)).map(move |_| { + // At this point, `pooled` is dropped, and had a chance + // to insert into the pool (if conn was idle) + drop(delayed_tx); + }); + + self.executor.execute(on_idle); + Poll::Ready(Ok(res)) + } + Poll::Pending => Poll::Pending, + } + } +} diff --git a/src/client/services/http1_request_target.rs b/src/client/services/http1_request_target.rs new file mode 100644 index 0000000..d28b571 --- /dev/null +++ b/src/client/services/http1_request_target.rs @@ -0,0 +1,85 @@ +use http::{uri::Scheme, Method, Request, Uri}; +use tower_service::Service; +use tracing::warn; + +pub struct Http1RequestTarget { + inner: S, +} + +impl Service> for Http1RequestTarget +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + // CONNECT always sends authority-form, so check it first... + if req.method() == Method::CONNECT { + authority_form(req.uri_mut()); + // TODO: this middleware must be connection pool aware + // } else if pooled.conn_info.is_proxied { + // absolute_form(req.uri_mut()); + } else { + origin_form(req.uri_mut()); + } + self.inner.call(req) + } +} + +fn origin_form(uri: &mut Uri) { + let path = match uri.path_and_query() { + Some(path) if path.as_str() != "/" => { + let mut parts = ::http::uri::Parts::default(); + parts.path_and_query = Some(path.clone()); + Uri::from_parts(parts).expect("path is valid uri") + } + _none_or_just_slash => { + debug_assert!(Uri::default() == "/"); + Uri::default() + } + }; + *uri = path +} + +fn absolute_form(uri: &mut Uri) { + debug_assert!(uri.scheme().is_some(), "absolute_form needs a scheme"); + debug_assert!( + uri.authority().is_some(), + "absolute_form needs an authority" + ); + // If the URI is to HTTPS, and the connector claimed to be a proxy, + // then it *should* have tunneled, and so we don't want to send + // absolute-form in that case. + if uri.scheme() == Some(&Scheme::HTTPS) { + origin_form(uri); + } +} + +fn authority_form(uri: &mut Uri) { + if let Some(path) = uri.path_and_query() { + // `https://hyper.rs` would parse with `/` path, don't + // annoy people about that... + if path != "/" { + warn!("HTTP/1.1 CONNECT request stripping path: {:?}", path); + } + } + *uri = match uri.authority() { + Some(auth) => { + let mut parts = ::http::uri::Parts::default(); + parts.authority = Some(auth.clone()); + Uri::from_parts(parts).expect("authority is valid") + } + None => { + unreachable!("authority_form with relative uri"); + } + }; +} diff --git a/src/client/services/mod.rs b/src/client/services/mod.rs new file mode 100644 index 0000000..1333bfb --- /dev/null +++ b/src/client/services/mod.rs @@ -0,0 +1,7 @@ +mod delayed_response; +mod http1_request_target; +mod set_host; + +pub use delayed_response::DelayedResponse; +pub use http1_request_target::Http1RequestTarget; +pub use set_host::SetHost; diff --git a/src/client/services/set_host.rs b/src/client/services/set_host.rs new file mode 100644 index 0000000..1f3838c --- /dev/null +++ b/src/client/services/set_host.rs @@ -0,0 +1,50 @@ +use std::task::{Context, Poll}; + +use http::{header::HOST, uri::Port, HeaderValue, Request, Uri}; +use hyper::service::Service; + +pub struct SetHost { + inner: S, +} + +impl Service> for SetHost +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let uri = req.uri().clone(); + req.headers_mut().entry(HOST).or_insert_with(|| { + let hostname = uri.host().expect("authority implies host"); + if let Some(port) = get_non_default_port(&uri) { + let s = format!("{}:{}", hostname, port); + HeaderValue::from_str(&s) + } else { + HeaderValue::from_str(hostname) + } + .expect("uri host is valid header value") + }); + self.inner.call(req) + } +} + +fn get_non_default_port(uri: &Uri) -> Option> { + match (uri.port().map(|p| p.as_u16()), is_schema_secure(uri)) { + (Some(443), true) => None, + (Some(80), false) => None, + _ => uri.port(), + } +} + +fn is_schema_secure(uri: &Uri) -> bool { + uri.scheme_str() + .map(|scheme_str| matches!(scheme_str, "wss" | "https")) + .unwrap_or_default() +} diff --git a/src/common/mod.rs b/src/common/mod.rs index 52b9917..82b29d9 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -12,5 +12,3 @@ macro_rules! ready { pub(crate) use ready; pub(crate) mod exec; pub(crate) mod never; - -pub(crate) use never::Never;