
Commit

Merge pull request #3 from junkurihara/feat/hyper-1.0
feat: hyper 1.0
junkurihara authored Nov 17, 2023
2 parents 07e8717 + b18646a commit f235d79
Showing 17 changed files with 312 additions and 161 deletions.
6 changes: 6 additions & 0 deletions .gitmodules
@@ -0,0 +1,6 @@
[submodule "submodules/hyper-util"]
path = submodules/hyper-util
url = git@github.com:junkurihara/hyper-util.git
[submodule "submodules/hyper-tls"]
path = submodules/hyper-tls
url = git@github.com:junkurihara/hyper-tls.git
1 change: 1 addition & 0 deletions Cargo.toml
@@ -1,6 +1,7 @@
[workspace]

members = ["relay-bin", "relay-lib"]
exclude = ["submodules/hyper-util/", "submodules/hyper-tls"]
resolver = "2"

[profile.release]
2 changes: 1 addition & 1 deletion relay-bin/Cargo.toml
@@ -55,4 +55,4 @@ hot_reload = "0.1.4"

# logging
tracing = { version = "0.1.40" }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
27 changes: 13 additions & 14 deletions relay-lib/Cargo.toml
@@ -51,21 +51,20 @@ async-trait = "0.1.74"
# http handling
url = "2.4.1"
rustc-hash = "1.1.0"
hyper = { version = "0.14.27", default-features = false, features = [
"server",
"http1",
"http2",
"stream",
] }
hyper-rustls = { version = "0.24.2", default-features = false, features = [
"tokio-runtime",
"webpki-tokio",
"http1",
"http2",
] }
tokio-rustls = { version = "0.24.1", features = ["early-data"] }
hyper = { version = "1.0.0", default-features = false }
http = "1.0.0"
http-body-util = "0.1.0"
hyper-util = { path = "../submodules/hyper-util/", features = ["full"] }
hyper-tls = { path = "../submodules/hyper-tls/", default-features = false }
# hyper-rustls = { version = "0.24.2", default-features = false, features = [
# "tokio-runtime",
# "webpki-tokio",
# "http1",
# "http2",
# ] }

# validation of id token
reqwest = { version = "0.11.22", features = ["json", "default"] }
# reqwest = { version = "0.11.22", features = ["json", "default"] }
serde = { version = "1.0.192", default-features = false }
auth-validator = { git = "https://github.com/junkurihara/rust-token-server", package = "rust-token-server-validator", branch = "develop" }
serde_json = { version = "1.0.108" }
2 changes: 2 additions & 0 deletions relay-lib/src/constants.rs
@@ -26,3 +26,5 @@ pub const MAX_DNS_QUESTION_LEN: usize = 512;
pub const JWKS_REFETCH_DELAY_SEC: u64 = 300;
/// HTTP request timeout for refetching JWKS
pub const JWKS_REFETCH_TIMEOUT_SEC: u64 = 3;
/// Expected maximum size of JWKS in bytes
pub const EXPECTED_MAX_JWKS_SIZE: u64 = 1024 * 64;
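Note: the new EXPECTED_MAX_JWKS_SIZE constant implies the JWKS fetch is bounded by size. Below is a minimal sketch of how such a cap can be enforced with http-body-util's Limited wrapper; it only illustrates the idea and is not necessarily how the validator in this repository consumes the constant (bytes and http-body-util are assumed as direct dependencies of the sketch).

```rust
use bytes::Bytes;
use http_body_util::{BodyExt, Full, Limited};

const EXPECTED_MAX_JWKS_SIZE: u64 = 1024 * 64;

#[tokio::main]
async fn main() {
    // Stand-in for a fetched JWKS response body.
    let body = Full::new(Bytes::from_static(br#"{"keys":[]}"#));
    // Limited turns any Body into one that errors out once the byte cap is exceeded.
    let limited = Limited::new(body, EXPECTED_MAX_JWKS_SIZE as usize);
    match limited.collect().await {
        Ok(collected) => println!("jwks bytes: {}", collected.to_bytes().len()),
        Err(_) => eprintln!("JWKS exceeded the expected maximum size"),
    }
}
```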
6 changes: 2 additions & 4 deletions relay-lib/src/error.rs
@@ -1,5 +1,5 @@
pub use anyhow::{anyhow, bail, ensure, Context};
use hyper::StatusCode;
use http::StatusCode;
use thiserror::Error;

pub type Result<T> = std::result::Result<T, RelayError>;
@@ -14,8 +14,6 @@ pub enum RelayError {
NoValidator,
#[error("Failed to build forwarder")]
BuildForwarderError,
#[error("Failed to build validator")]
BuildValidatorError,
#[error(transparent)]
Other(#[from] anyhow::Error),
}
@@ -55,7 +53,7 @@ pub enum HttpError {
TooLargeRequestBody,

#[error("Failed to send request")]
SendRequestError(#[from] hyper::Error),
SendRequestError(#[from] hyper_util::client::legacy::Error),
#[error("Invalid response content type")]
InvalidResponseContentType,
#[error("Invalid response body")]
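Note: with hyper 1.0 the high-level client moves to hyper-util, so the #[from] source of SendRequestError changes from hyper::Error to hyper_util::client::legacy::Error. A trimmed-down sketch of the pattern, mirroring rather than reproducing the enum in relay-lib/src/error.rs:

```rust
use thiserror::Error;

#[derive(Error, Debug)]
pub enum HttpError {
    // The legacy client's error type converts automatically thanks to #[from],
    // so callers can use `?` on results returned by the hyper-util client.
    #[error("Failed to send request")]
    SendRequestError(#[from] hyper_util::client::legacy::Error),
}
```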
60 changes: 60 additions & 0 deletions relay-lib/src/http_client.rs
@@ -0,0 +1,60 @@
use crate::hyper_executor::LocalExecutor;
use http::Request;
use hyper::body::Body;
use hyper_tls::HttpsConnector;
use hyper_util::client::{
connect::{Connect, HttpConnector},
legacy::{Client, ResponseFuture},
};

#[derive(Clone)]
/// Http client that is used for forwarding requests to upstream and fetching jwks from auth server.
pub struct HttpClient<C, B>
where
C: Send + Sync + Connect + Clone + 'static,
B: Body + Send + Unpin + 'static,
<B as Body>::Data: Send,
<B as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>,
{
pub inner: Client<C, B>,
}

impl<C, B> HttpClient<C, B>
where
C: Send + Sync + Connect + Clone + 'static,
B: Body + Send + Unpin + 'static,
<B as Body>::Data: Send,
<B as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>,
{
/// wrapper request fn
pub fn request(&self, req: Request<B>) -> ResponseFuture {
self.inner.request(req)
}
}

impl<B> HttpClient<HttpsConnector<HttpConnector>, B>
where
B: Body + Send + Unpin + 'static,
<B as Body>::Data: Send,
<B as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>,
{
/// Build inner client with hyper-tls
pub fn new(runtime_handle: tokio::runtime::Handle) -> Self {
// build hyper client with hyper-tls, only https is allowed
let mut connector = HttpsConnector::new();
connector.https_only(true);
let executor = LocalExecutor::new(runtime_handle.clone());
let inner = Client::builder(executor).build::<_, B>(connector);

// build hyper client with rustls and webpki, only https is allowed
// let connector = hyper_rustls::HttpsConnectorBuilder::new()
// .with_webpki_roots()
// .https_only()
// .enable_http1()
// .enable_http2()
// .build();
// let inner = Client::builder(TokioExecutor::new()).build::<_, B>(connector);

Self { inner }
}
}
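Note: a hypothetical usage sketch of the HttpClient wrapper above, assuming it were reachable from outside the crate (the module is private here). Import paths follow the forked hyper-util pinned as a submodule in this PR, the request body is http-body-util's Empty, and the URL is a placeholder.

```rust
use bytes::Bytes;
use http::Request;
use http_body_util::Empty;
use hyper_tls::HttpsConnector;
use hyper_util::client::connect::HttpConnector;
// use relay_lib::http_client::HttpClient; // hypothetical: the module is private in this PR

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Client built the way HttpClient::new does it: hyper-tls connector, https only.
    let client: HttpClient<HttpsConnector<HttpConnector>, Empty<Bytes>> =
        HttpClient::new(tokio::runtime::Handle::current());

    let req = Request::builder()
        .uri("https://relay.example/.well-known/jwks.json") // placeholder URL
        .body(Empty::<Bytes>::new())?;

    let res = client.request(req).await?;
    println!("status: {}", res.status());
    Ok(())
}
```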
23 changes: 23 additions & 0 deletions relay-lib/src/hyper_executor.rs
@@ -0,0 +1,23 @@
use tokio::runtime::Handle;

#[derive(Clone)]
/// Executor for hyper
pub struct LocalExecutor {
runtime_handle: Handle,
}

impl LocalExecutor {
pub fn new(runtime_handle: Handle) -> Self {
LocalExecutor { runtime_handle }
}
}

impl<F> hyper::rt::Executor<F> for LocalExecutor
where
F: std::future::Future + Send + 'static,
F::Output: Send,
{
fn execute(&self, fut: F) {
self.runtime_handle.spawn(fut);
}
}
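Note: hyper 1.0 no longer ships a built-in Tokio executor, so the crate provides one by implementing hyper::rt::Executor over a runtime handle. A self-contained restatement with a tiny usage example; the module above is private, so the type is redefined here and the spawned future is only illustrative.

```rust
use hyper::rt::Executor;
use tokio::runtime::Handle;

#[derive(Clone)]
struct LocalExecutor {
    runtime_handle: Handle,
}

impl<F> Executor<F> for LocalExecutor
where
    F: std::future::Future + Send + 'static,
    F::Output: Send,
{
    fn execute(&self, fut: F) {
        // Every future hyper hands over is spawned onto the shared Tokio runtime.
        self.runtime_handle.spawn(fut);
    }
}

#[tokio::main]
async fn main() {
    let exec = LocalExecutor { runtime_handle: Handle::current() };
    // In the relay, hyper drives connection tasks through this executor;
    // here we just schedule a trivial future to show the call shape.
    exec.execute(async { println!("spawned on the shared runtime") });
    tokio::task::yield_now().await;
}
```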
2 changes: 2 additions & 0 deletions relay-lib/src/lib.rs
@@ -1,6 +1,8 @@
mod constants;
mod error;
mod globals;
mod http_client;
mod hyper_executor;
mod log;
mod relay;
mod validator;
100 changes: 54 additions & 46 deletions relay-lib/src/relay/forwarder.rs
@@ -1,17 +1,17 @@
use crate::{constants::*, error::*, globals::Globals, log::*};
use futures::stream::StreamExt;
use hyper::{
client::{connect::Connect, HttpConnector},
header,
http::{request::Parts, HeaderMap, HeaderValue, Method},
Body, Client, Request, Response,
use crate::{constants::*, error::*, globals::Globals, http_client::HttpClient, log::*};
use http::{
header::{self, HeaderMap, HeaderValue},
request::Parts,
Method, Request, Response,
};
use hyper_rustls::HttpsConnector;
use hyper::body::{Body, Incoming};
use hyper_tls::HttpsConnector;
use hyper_util::client::connect::{Connect, HttpConnector};
use std::{net::SocketAddr, sync::Arc};
use url::Url;

/// parse and check content type and accept headers if both or either of them are "application/oblivious-dns-message".
fn check_content_type<T>(req: &Request<T>) -> HttpResult<()> {
fn check_content_type<B>(req: &Request<B>) -> HttpResult<()> {
// check content type
if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
let Ok(ct) = content_type.to_str() else {
@@ -40,24 +40,34 @@ fn check_content_type<B>(req: &Request<B>) -> HttpResult<()> {
}

/// Read encrypted query from request body
async fn read_body(body: &mut Body) -> HttpResult<Vec<u8>> {
let mut sum_size = 0;
let mut query = vec![];
while let Some(chunk) = body.next().await {
let chunk = chunk.map_err(|_| HttpError::TooLargeRequestBody)?;
sum_size += chunk.len();
if sum_size >= MAX_DNS_QUESTION_LEN {
return Err(HttpError::TooLargeRequestBody);
}
query.extend(chunk);
async fn inspect_request_body<B: Body>(body: &B) -> HttpResult<()> {
let max = body.size_hint().upper().unwrap_or(u64::MAX);
if max > MAX_DNS_QUESTION_LEN as u64 {
return Err(HttpError::TooLargeRequestBody);
}
if max == 0 {
return Err(HttpError::NoBodyInRequest);
}
Ok(query)
// Ok(EitherBody::Left(body))
Ok(())

// let mut sum_size = 0;
// let mut query = vec![];
// while let Some(chunk) = body.next().await {
// let chunk = chunk.map_err(|_| HttpError::TooLargeRequestBody)?;
// sum_size += chunk.len();
// if sum_size >= MAX_DNS_QUESTION_LEN {
// return Err(HttpError::TooLargeRequestBody);
// }
// query.extend(chunk);
// }
// Ok(query)
}

/// Get HOST header and/or host name in url line in http request
/// Returns Err if both are specified and inconsistent, or if none of them is specified.
/// Note that port is dropped even if specified.
fn inspect_get_host(req: &Request<Body>) -> HttpResult<String> {
fn inspect_get_host<B>(req: &Request<B>) -> HttpResult<String> {
let drop_port = |v: &str| {
v.split(':')
.next()
@@ -81,13 +91,16 @@ fn inspect_get_host(req: &Request<Body>) -> HttpResult<String> {
}
}

/// wrapper of reqwest client
pub struct InnerForwarder<C, B = Body>
/// wrapper of http client
pub struct InnerForwarder<C, B = Incoming>
where
C: Send + Sync + Connect + Clone + 'static,
B: Body + Send + Unpin + 'static,
<B as Body>::Data: Send,
<B as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>,
{
/// hyper client
pub(super) inner: Client<C, B>,
pub(super) inner: HttpClient<C, B>,
/// request default headers
pub(super) request_headers: HeaderMap,
/// relay host name
@@ -98,9 +111,12 @@ where
pub(super) max_subseq_nodes: usize,
}

impl<C> InnerForwarder<C>
impl<C, B> InnerForwarder<C, B>
where
C: Send + Sync + Connect + Clone + 'static,
B: Body + Send + Unpin + 'static,
<B as Body>::Data: Send,
<B as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>,
{
/// Serve request as relay
/// 1. check host, method and listening path: as described in [RFC9230](https://datatracker.ietf.org/doc/rfc9230/) and Golang implementation [odoh-server-go](https://github.com/cloudflare/odoh-server-go), only post method is allowed.
@@ -112,10 +128,10 @@ where
/// c.f., "Proxy-Status" [RFC9209](https://datatracker.ietf.org/doc/rfc9209).
pub async fn serve(
&self,
req: Request<Body>,
req: Request<B>,
peer_addr: SocketAddr,
validation_passed: bool,
) -> HttpResult<Response<Body>> {
) -> HttpResult<Response<Incoming>> {
// TODO: source ip access control here?
// for authorized ip addresses, maintain blacklist (error metrics) at each relay for given requests

@@ -156,16 +172,15 @@ where
// for authorized domains, maintain blacklist (error metrics) at each relay for given responses

// split request into parts and body to manipulate them later
let (mut parts, mut body) = req.into_parts();
// check if body is a valid odoh query and serve it
let encrypted_query = read_body(&mut body).await?;
if encrypted_query.is_empty() {
return Err(HttpError::NoBodyInRequest);
}
let (mut parts, body) = req.into_parts();
// check if body does not exceed max size as a DNS query
inspect_request_body(&body).await?;

// Forward request to next hop: Only post method is allowed in ODoH
self.update_request_parts(&nexthop_url, &mut parts)?;
let updated_request = Request::from_parts(parts, Body::from(encrypted_query.to_owned()));
let updated_request = Request::from_parts(parts, body);

// let updated_request = Request::from_parts(parts, Body::from(encrypted_query.to_owned()));
let mut response = match self.inner.request(updated_request).await {
Ok(res) => res,
Err(e) => {
Expand All @@ -174,7 +189,7 @@ where
}
};
// Inspect and update response
self.inspect_update_response(&mut response)?;
self.inspect_and_update_response_header(&mut response)?;

Ok(response)
}
@@ -194,12 +209,12 @@ where
Ok(())
}

/// inspect and update response
/// inspect and update response header
/// (M)ODoH response MUST NOT be cached as specified in [RFC9230](https://datatracker.ietf.org/doc/rfc9230/),
/// and hence "no-cache, no-store" is set (overwritten) in cache-control header.
/// Also "Proxy-Status" header with a received-status param is appended (overwritten) as described in [RFC9230, Section 4.3](https://datatracker.ietf.org/doc/rfc9230/).
/// c.f., "Proxy-Status" [RFC9209](https://datatracker.ietf.org/doc/rfc9209).
fn inspect_update_response(&self, response: &mut Response<Body>) -> HttpResult<()> {
fn inspect_and_update_response_header<T>(&self, response: &mut Response<T>) -> HttpResult<()> {
let status = response.status();
let proxy_status = format!("received-status={}", status);

@@ -225,7 +240,7 @@ where
}
}

impl InnerForwarder<HttpsConnector<HttpConnector>, Body> {
impl InnerForwarder<HttpsConnector<HttpConnector>> {
/// Build inner forwarder
pub fn try_new(globals: &Arc<Globals>) -> Result<Self> {
// default headers for request
@@ -239,14 +254,7 @@ impl InnerForwarder<HttpsConnector<HttpConnector>, Body> {
request_headers.insert(header::CACHE_CONTROL, HeaderValue::from_static(ODOH_CACHE_CONTROL));
request_headers.insert(header::USER_AGENT, user_agent);

// build hyper client with rustls and webpki, only https is allowed
let connector = hyper_rustls::HttpsConnectorBuilder::new()
.with_webpki_roots()
.https_only()
.enable_http1()
.enable_http2()
.build();
let inner = Client::builder().build::<_, Body>(connector);
let inner = HttpClient::new(globals.runtime_handle.clone());

let relay_host = globals.relay_config.hostname.clone();
let relay_path = globals.relay_config.path.clone();
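Note: the old chunk-by-chunk read of the request body is replaced by a size_hint() check on the new Body trait, so the encrypted query is forwarded without buffering. A self-contained sketch of that check follows; the constant and comparisons mirror the PR, while Full<Bytes> stands in for hyper's Incoming, which cannot be constructed directly in an example like this.

```rust
use bytes::Bytes;
use http_body_util::Full;
use hyper::body::Body;

const MAX_DNS_QUESTION_LEN: usize = 512;

/// Returns true when the body's upper size bound is non-zero and within the cap.
fn body_within_limit<B: Body>(body: &B) -> bool {
    let upper = body.size_hint().upper().unwrap_or(u64::MAX);
    upper != 0 && upper <= MAX_DNS_QUESTION_LEN as u64
}

fn main() {
    let small = Full::new(Bytes::from_static(b"\x00\x00")); // within the limit
    let large = Full::new(Bytes::from(vec![0u8; 4096]));    // rejected as too large
    assert!(body_within_limit(&small));
    assert!(!body_within_limit(&large));
    println!("size-hint checks behave as expected");
}
```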
(Diffs for the remaining 7 changed files are not shown here.)
