-
Notifications
You must be signed in to change notification settings - Fork 106
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: feat: SetHost, Http1RequestTarget and DelayedResposne middlewares
- Loading branch information
Showing
6 changed files
with
201 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ pub mod connect; | |
pub mod legacy; | ||
#[doc(hidden)] | ||
pub mod pool; | ||
pub mod services; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
use std::{ | ||
future, | ||
pin::Pin, | ||
task::{Context, Poll}, | ||
}; | ||
|
||
use futures_channel::oneshot; | ||
use futures_util::Future; | ||
use http::Response; | ||
use hyper::body::Body; | ||
use hyper::service::Service; | ||
use pin_project_lite::pin_project; | ||
|
||
pub struct DelayedResponse<S> { | ||
inner: S, | ||
} | ||
|
||
impl<S, Req, B> Service<Req> for DelayedResponse<S> | ||
where | ||
S: Service<Req, Response = Response<B>>, | ||
B: Body, | ||
{ | ||
type Response = S::Response; | ||
type Error = S::Error; | ||
type Future = DelayedResponseFuture<S::Future>; | ||
|
||
fn call(&self, req: Req) -> Self::Future { | ||
DelayedResponseFuture { | ||
inner: self.inner.call(req), | ||
} | ||
} | ||
} | ||
|
||
pin_project! { | ||
struct DelayedResponseFuture<F> { | ||
#[pin] | ||
inner: F, | ||
} | ||
} | ||
|
||
impl<F, E, B> Future for DelayedResponseFuture<F> | ||
where | ||
F: Future<Output = Result<Response<B>, E>>, | ||
B: Body, | ||
{ | ||
type Output = F::Output; | ||
|
||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||
match self.project().inner.poll(cx) { | ||
Poll::Ready(res) => { | ||
let res = res?; | ||
let (delayed_tx, delayed_rx) = oneshot::channel(); | ||
res.body_mut().delayed_eof(delayed_rx); | ||
let on_idle = future::poll_fn(move |cx| pooled.poll_ready(cx)).map(move |_| { | ||
// At this point, `pooled` is dropped, and had a chance | ||
// to insert into the pool (if conn was idle) | ||
drop(delayed_tx); | ||
}); | ||
|
||
self.executor.execute(on_idle); | ||
Poll::Ready(Ok(res)) | ||
} | ||
Poll::Pending => Poll::Pending, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
use http::{uri::Scheme, Method, Request, Uri}; | ||
use hyper::service::Service; | ||
use tracing::warn; | ||
|
||
pub struct Http1RequestTarget<S> { | ||
inner: S, | ||
} | ||
|
||
impl<S, B> Service<Request<B>> for Http1RequestTarget<S> | ||
where | ||
S: Service<Request<B>>, | ||
{ | ||
type Response = S::Response; | ||
type Error = S::Error; | ||
type Future = S::Future; | ||
|
||
fn call(&self, mut req: Request<B>) -> Self::Future { | ||
// CONNECT always sends authority-form, so check it first... | ||
if req.method() == Method::CONNECT { | ||
authority_form(req.uri_mut()); | ||
// TODO: this middleware must be connection pool aware | ||
// } else if pooled.conn_info.is_proxied { | ||
// absolute_form(req.uri_mut()); | ||
} else { | ||
origin_form(req.uri_mut()); | ||
} | ||
self.inner.call(req) | ||
} | ||
} | ||
|
||
fn origin_form(uri: &mut Uri) { | ||
let path = match uri.path_and_query() { | ||
Some(path) if path.as_str() != "/" => { | ||
let mut parts = ::http::uri::Parts::default(); | ||
parts.path_and_query = Some(path.clone()); | ||
Uri::from_parts(parts).expect("path is valid uri") | ||
} | ||
_none_or_just_slash => { | ||
debug_assert!(Uri::default() == "/"); | ||
Uri::default() | ||
} | ||
}; | ||
*uri = path | ||
} | ||
|
||
fn absolute_form(uri: &mut Uri) { | ||
debug_assert!(uri.scheme().is_some(), "absolute_form needs a scheme"); | ||
debug_assert!( | ||
uri.authority().is_some(), | ||
"absolute_form needs an authority" | ||
); | ||
// If the URI is to HTTPS, and the connector claimed to be a proxy, | ||
// then it *should* have tunneled, and so we don't want to send | ||
// absolute-form in that case. | ||
if uri.scheme() == Some(&Scheme::HTTPS) { | ||
origin_form(uri); | ||
} | ||
} | ||
|
||
fn authority_form(uri: &mut Uri) { | ||
if let Some(path) = uri.path_and_query() { | ||
// `https://hyper.rs` would parse with `/` path, don't | ||
// annoy people about that... | ||
if path != "/" { | ||
warn!("HTTP/1.1 CONNECT request stripping path: {:?}", path); | ||
} | ||
} | ||
*uri = match uri.authority() { | ||
Some(auth) => { | ||
let mut parts = ::http::uri::Parts::default(); | ||
parts.authority = Some(auth.clone()); | ||
Uri::from_parts(parts).expect("authority is valid") | ||
} | ||
None => { | ||
unreachable!("authority_form with relative uri"); | ||
} | ||
}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
mod delayed_response; | ||
mod http1_request_target; | ||
mod set_host; | ||
|
||
pub use delayed_response::DelayedResponse; | ||
pub use http1_request_target::Http1RequestTarget; | ||
pub use set_host::SetHost; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
use http::{header::HOST, uri::Port, HeaderValue, Request, Uri}; | ||
use hyper::service::Service; | ||
|
||
pub struct SetHost<S> { | ||
inner: S, | ||
} | ||
|
||
impl<S, B> Service<Request<B>> for SetHost<S> | ||
where | ||
S: Service<Request<B>>, | ||
{ | ||
type Response = S::Response; | ||
type Error = S::Error; | ||
type Future = S::Future; | ||
|
||
fn call(&self, mut req: Request<B>) -> Self::Future { | ||
let uri = req.uri().clone(); | ||
req.headers_mut().entry(HOST).or_insert_with(|| { | ||
let hostname = uri.host().expect("authority implies host"); | ||
if let Some(port) = get_non_default_port(&uri) { | ||
let s = format!("{}:{}", hostname, port); | ||
HeaderValue::from_str(&s) | ||
} else { | ||
HeaderValue::from_str(hostname) | ||
} | ||
.expect("uri host is valid header value") | ||
}); | ||
self.inner.call(req) | ||
} | ||
} | ||
|
||
fn get_non_default_port(uri: &Uri) -> Option<Port<&str>> { | ||
match (uri.port().map(|p| p.as_u16()), is_schema_secure(uri)) { | ||
(Some(443), true) => None, | ||
(Some(80), false) => None, | ||
_ => uri.port(), | ||
} | ||
} | ||
|
||
fn is_schema_secure(uri: &Uri) -> bool { | ||
uri.scheme_str() | ||
.map(|scheme_str| matches!(scheme_str, "wss" | "https")) | ||
.unwrap_or_default() | ||
} |