diff --git a/Cargo.lock b/Cargo.lock index d8317e9..515ddf9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -542,6 +542,7 @@ version = "0.0.1" dependencies = [ "futures", "hex-conservative", + "http", "http-body-util", "hyper", "hyper-rustls", diff --git a/Cargo.toml b/Cargo.toml index b95ea07..a2674bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"] [dependencies] futures = { version = "0.3", optional = true } +http = "1" http-body-util = "0.1" hyper = { version = "1", features = ["http1", "server"] } hyper-rustls = { version = "0.26", features = ["webpki-roots"] } diff --git a/src/bootstrap/connect.rs b/src/bootstrap/connect.rs index 8c1be72..154b3d9 100644 --- a/src/bootstrap/connect.rs +++ b/src/bootstrap/connect.rs @@ -1,5 +1,7 @@ +use std::net::SocketAddr; use std::sync::Arc; +use http::Uri; use http_body_util::combinators::BoxBody; use hyper::body::{Bytes, Incoming}; use hyper::upgrade::Upgraded; @@ -7,8 +9,8 @@ use hyper::{Method, Request, Response}; use hyper_util::rt::TokioIo; use tokio::net::TcpStream; -use crate::empty; use crate::error::Error; +use crate::{empty, uri_to_addr}; pub(crate) fn is_connect_request(req: &Request) -> bool { Method::CONNECT == req.method() @@ -16,7 +18,7 @@ pub(crate) fn is_connect_request(req: &Request) -> bool { pub(crate) async fn try_upgrade( req: Request, - gateway_origin: Arc, + gateway_origin: Arc, ) -> Result>, Error> { if let Some(addr) = find_allowable_gateway(&req, &gateway_origin) { tokio::task::spawn(async move { @@ -38,7 +40,7 @@ pub(crate) async fn try_upgrade( /// Create a TCP connection to host:port, build a tunnel between the connection and /// the upgraded connection -async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> { +async fn tunnel(upgraded: Upgraded, addr: SocketAddr) -> std::io::Result<()> { let mut server = TcpStream::connect(addr).await?; let mut upgraded = TokioIo::new(upgraded); let (_, _) = tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?; @@ -48,36 +50,34 @@ async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> { /// Only allow CONNECT requests to the configured OHTTP gateway authority. /// This prevents the relay from being used as an arbitrary proxy /// to any host on the internet. -fn find_allowable_gateway(req: &Request, gateway_origin: &str) -> Option { - let gateway_authority = - gateway_origin.trim_start_matches("https://").trim_start_matches("http://"); - let req_gateway = req.uri().authority().map(|auth| auth.to_string()); - if req_gateway == Some(gateway_authority.to_string()) { - req_gateway - } else { - None +fn find_allowable_gateway(req: &Request, gateway_origin: &Uri) -> Option { + if req.uri().authority() != gateway_origin.authority() { + return None; } + + uri_to_addr(gateway_origin) } #[cfg(test)] mod test { use hyper::Request; + use once_cell::sync::Lazy; use super::*; + static GATEWAY_ORIGIN: Lazy = Lazy::new(|| Uri::from_static("https://gateway.com")); + #[test] fn mismatched_gateways_not_allowed() { - let gateway_origin = "https://gateway.com"; let not_gateway_origin = "https://not-gateway.com"; let req = hyper::Request::builder().uri(not_gateway_origin).body(()).unwrap(); - let allowable_gateway = find_allowable_gateway(&req, gateway_origin); + let allowable_gateway = find_allowable_gateway(&req, &*GATEWAY_ORIGIN); assert!(allowable_gateway.is_none()); } #[test] fn matched_gateways_allowed() { - let gateway_origin = "https://gateway.com"; - let req = Request::builder().uri(gateway_origin).body(()).unwrap(); - assert!(find_allowable_gateway(&req, gateway_origin).is_some()); + let req = Request::builder().uri(&*GATEWAY_ORIGIN).body(()).unwrap(); + assert!(find_allowable_gateway(&req, &*GATEWAY_ORIGIN).is_some()); } } diff --git a/src/bootstrap/mod.rs b/src/bootstrap/mod.rs index 3ab8343..763a34b 100644 --- a/src/bootstrap/mod.rs +++ b/src/bootstrap/mod.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use http::Uri; use http_body_util::combinators::BoxBody; use hyper::body::{Bytes, Incoming}; use hyper::{Request, Response}; @@ -14,7 +15,7 @@ pub mod ws; pub(crate) async fn handle_ohttp_keys( mut req: Request, - gateway_origin: Arc, + gateway_origin: Arc, ) -> Result>, Error> { #[cfg(feature = "connect-bootstrap")] if connect::is_connect_request(&req) { diff --git a/src/bootstrap/ws.rs b/src/bootstrap/ws.rs index 08fc298..ab3c33d 100644 --- a/src/bootstrap/ws.rs +++ b/src/bootstrap/ws.rs @@ -1,9 +1,11 @@ use std::io; +use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use futures::{Sink, SinkExt, StreamExt}; +use http::Uri; use http_body_util::combinators::BoxBody; use http_body_util::BodyExt; use hyper::body::{Bytes, Incoming}; @@ -14,6 +16,7 @@ use tokio_tungstenite::tungstenite::protocol::Message; use tokio_tungstenite::{tungstenite, WebSocketStream}; use crate::error::Error; +use crate::uri_to_addr; pub(crate) fn is_websocket_request(req: &Request) -> bool { hyper_tungstenite::is_upgrade_request(req) @@ -21,12 +24,13 @@ pub(crate) fn is_websocket_request(req: &Request) -> bool { pub(crate) async fn try_upgrade( req: &mut Request, - gateway_origin: Arc, + gateway_origin: Arc, ) -> Result>, Error> { let (res, websocket) = hyper_tungstenite::upgrade(req, None) .map_err(|e| Error::BadRequest(format!("Error upgrading to websocket: {}", e)))?; + let gateway_addr = uri_to_addr(&gateway_origin).ok_or(Error::InternalServerError)?; tokio::spawn(async move { - if let Err(e) = serve_websocket(websocket, gateway_origin.as_str()).await { + if let Err(e) = serve_websocket(websocket, gateway_addr).await { eprintln!("Error in websocket connection: {e}"); } }); @@ -38,10 +42,9 @@ pub(crate) async fn try_upgrade( /// Stream WebSocket frames from the client to the gateway server's TCP socket and vice versa. async fn serve_websocket( websocket: HyperWebsocket, - gateway_origin: &str, + gateway_addr: SocketAddr, ) -> Result<(), Box> { - let addr = gateway_origin.trim_start_matches("https://").trim_start_matches("http://"); - let mut tcp_stream = tokio::net::TcpStream::connect(addr).await?; + let mut tcp_stream = tokio::net::TcpStream::connect(gateway_addr).await?; let mut ws_io = WsIo::new(websocket.await?); let (_, _) = tokio::io::copy_bidirectional(&mut ws_io, &mut tcp_stream).await?; Ok(()) diff --git a/src/error.rs b/src/error.rs index cff9621..1dc79ee 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,6 +12,8 @@ pub(crate) enum Error { UnsupportedMediaType, BadRequest(String), NotFound, + #[allow(clippy::enum_variant_names)] + InternalServerError, } impl Error { @@ -26,6 +28,7 @@ impl Error { *res.body_mut() = full(e.to_string()).boxed(); } Self::NotFound => *res.status_mut() = StatusCode::NOT_FOUND, + Self::InternalServerError => *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR, }; res @@ -40,6 +43,7 @@ impl std::fmt::Display for Error { Self::MethodNotAllowed => write!(f, "Method not allowed"), Self::BadRequest(e) => write!(f, "Bad request: {}", e), Self::NotFound => write!(f, "Not found"), + Self::InternalServerError => write!(f, "Internal server error"), } } } diff --git a/src/lib.rs b/src/lib.rs index 17bd4f8..28505d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,8 @@ -use std::net::SocketAddr; +use std::net::{SocketAddr, ToSocketAddrs}; use std::sync::Arc; +use http::uri::PathAndQuery; +use http::Uri; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Empty, Full}; use hyper::body::{Bytes, Incoming}; @@ -24,13 +26,13 @@ pub mod bootstrap; pub const DEFAULT_PORT: u16 = 3000; pub static OHTTP_RELAY_HOST: Lazy = - Lazy::new(|| HeaderValue::from_str("localhost").expect("Invalid HeaderValue")); + Lazy::new(|| HeaderValue::from_str("0.0.0.0").expect("Invalid HeaderValue")); pub static EXPECTED_MEDIA_TYPE: Lazy = Lazy::new(|| HeaderValue::from_str("message/ohttp-req").expect("Invalid HeaderValue")); pub async fn listen_tcp( port: u16, - gateway_origin: String, + gateway_origin: Uri, ) -> Result<(), Box> { let addr = SocketAddr::from(([127, 0, 0, 1], port)); let listener = TcpListener::bind(addr).await?; @@ -40,7 +42,7 @@ pub async fn listen_tcp( pub async fn listen_socket( socket_path: &str, - gateway_origin: String, + gateway_origin: Uri, ) -> Result<(), Box> { let listener = UnixListener::bind(socket_path)?; println!("OHTTP relay listening on socket: {}", socket_path); @@ -49,7 +51,7 @@ pub async fn listen_socket( async fn ohttp_relay( mut listener: L, - gateway_origin: String, + gateway_origin: Uri, ) -> Result<(), Box> where L: Listener + Unpin, @@ -79,11 +81,11 @@ where async fn serve_ohttp_relay( req: Request, - gateway_origin: Arc, + gateway_origin: Arc, ) -> Result>, hyper::Error> { println!("req: {:?}", req); let res = match req.method() { - &Method::POST => handle_ohttp_relay(req, gateway_origin.as_str()).await, + &Method::POST => handle_ohttp_relay(req, &gateway_origin).await, #[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))] &Method::CONNECT | &Method::GET => crate::bootstrap::handle_ohttp_keys(req, gateway_origin).await, @@ -95,7 +97,7 @@ async fn serve_ohttp_relay( async fn handle_ohttp_relay( req: Request, - gateway_origin: &str, + gateway_origin: &Uri, ) -> Result>, Error> { let fwd_req = into_forward_req(req, gateway_origin)?; forward_request(fwd_req).await.map(|res| { @@ -108,7 +110,7 @@ async fn handle_ohttp_relay( /// Convert an incoming request into a request to forward to the target gateway server. fn into_forward_req( mut req: Request, - gateway_origin: &str, + gateway_origin: &Uri, ) -> Result, Error> { if req.method() != hyper::Method::POST { return Err(Error::MethodNotAllowed); @@ -124,13 +126,17 @@ fn into_forward_req( req.headers_mut().insert(CONTENT_LENGTH, content_length); } - let uri_string = format!( - "{}{}", - gateway_origin, - req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("/") - ); - let uri = uri_string.parse().map_err(|_| Error::BadRequest("Invalid target uri".to_owned()))?; - *req.uri_mut() = uri; + let req_path_and_query = + req.uri().path_and_query().map_or_else(|| PathAndQuery::from_static("/"), |pq| pq.clone()); + + *req.uri_mut() = Uri::builder() + .scheme(gateway_origin.scheme_str().unwrap_or("https")) + .authority( + gateway_origin.authority().expect("Gateway origin must have an authority").as_str(), + ) + .path_and_query(req_path_and_query.as_str()) + .build() + .map_err(|_| Error::BadRequest("Invalid target uri".to_owned()))?; Ok(req) } @@ -141,6 +147,21 @@ async fn forward_request(req: Request) -> Result, E client.request(req).await.map_err(|_| Error::BadGateway) } +pub(crate) fn uri_to_addr(uri: &Uri) -> Option { + let authority = uri.authority()?.as_str(); + let parts: Vec<&str> = authority.split(':').collect(); + let host = parts.first()?; + let port = parts.get(1).and_then(|p| p.parse::().ok()); + + let default_port = match uri.scheme_str() { + Some("https") => 443, + _ => 80, // Default to 80 if it's not https or if the scheme is not specified + }; + + let addr_str = format!("{}:{}", host, port.unwrap_or(default_port)); + addr_str.to_socket_addrs().ok()?.next() +} + pub(crate) fn empty() -> BoxBody { Empty::::new().map_err(|never| match never {}).boxed() } diff --git a/src/main.rs b/src/main.rs index 13633f7..b444a62 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,14 @@ +use std::str::FromStr; + +use http::Uri; use ohttp_relay::DEFAULT_PORT; #[tokio::main] async fn main() -> Result<(), Box> { let port_env = std::env::var("PORT"); let unix_socket_env = std::env::var("UNIX_SOCKET"); - let gateway_origin = std::env::var("GATEWAY_ORIGIN").expect("GATEWAY_ORIGIN is required"); + let gateway_origin_str = std::env::var("GATEWAY_ORIGIN").expect("GATEWAY_ORIGIN is required"); + let gateway_origin = Uri::from_str(&gateway_origin_str).expect("Invalid GATEWAY_ORIGIN URI"); match (port_env, unix_socket_env) { (Ok(_), Ok(_)) => panic!( diff --git a/tests/integration.rs b/tests/integration.rs index fa0ea90..90ae6c8 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,8 +1,10 @@ #[cfg(test)] mod integration { use std::net::SocketAddr; + use std::str::FromStr; use hex::FromHex; + use http::Uri; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full}; use hyper::body::{Bytes, Incoming}; @@ -24,12 +26,13 @@ mod integration { #[tokio::test] async fn test_request_response() { let gateway_port = find_free_port(); + let gateway = Uri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap(); let relay_port = find_free_port(); tokio::select! { _ = example_gateway_http(gateway_port) => { assert!(false, "Gateway is long running"); } - _ = listen_tcp(relay_port, format!("http://localhost:{}", gateway_port)) => { + _ = listen_tcp(relay_port, gateway) => { assert!(false, "Relay is long running"); } _ = ohttp_req_over_tcp(relay_port) => {} @@ -46,12 +49,13 @@ mod integration { } let gateway_port = find_free_port(); + let gateway = Uri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap(); let socket_path_str = socket_path.to_str().unwrap(); tokio::select! { _ = example_gateway_http(gateway_port) => { assert!(false, "Gateway is long running"); } - _ = listen_socket(socket_path_str, format!("http://localhost:{}", gateway_port)) => { + _ = listen_socket(socket_path_str, gateway) => { assert!(false, "Relay is long running"); } _ = ohttp_req_over_unix_socket(socket_path_str) => {} @@ -199,13 +203,13 @@ mod integration { .with_root_certificates(root_store) .with_no_client_auth(); - let (ws_stream, _res) = connect_async(format!("ws://localhost:{}", relay_port)) + let (ws_stream, _res) = connect_async(format!("ws://0.0.0.0:{}", relay_port)) .await .expect("Failed to connect"); println!("Connected to ws"); let ws_io = WsIo::new(ws_stream); let connector = TlsConnector::from(Arc::new(config)); - let domain = pki_types::ServerName::try_from("localhost") + let domain = pki_types::ServerName::try_from("0.0.0.0") .map_err(|_| { std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid dnsname") }) @@ -214,7 +218,7 @@ mod integration { let mut tls_stream = connector.connect(domain, ws_io).await.unwrap(); let content = - b"GET /ohttp-keys HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; + b"GET /ohttp-keys HTTP/1.1\r\nHost: 0.0.0.0\r\nConnection: close\r\n\r\n"; tls_stream.write_all(content).await.unwrap(); tls_stream.flush().await.unwrap(); let mut plaintext = Vec::new(); @@ -248,12 +252,12 @@ mod integration { .with_root_certificates(root_store) .with_no_client_auth(); let proxy = - ureq::Proxy::new(format!("http://localhost:{}", relay_port).as_str()).unwrap(); + ureq::Proxy::new(format!("http://0.0.0.0:{}", relay_port).as_str()).unwrap(); let https = ureq::AgentBuilder::new().tls_config(Arc::new(config)).proxy(proxy).build(); let res = tokio::task::spawn_blocking(move || { https - .get(format!("https://localhost:{}/ohttp-keys", gateway_port).as_str()) + .get(format!("https://0.0.0.0:{}/ohttp-keys", gateway_port).as_str()) .call() .unwrap() }) @@ -270,6 +274,7 @@ mod integration { F: FnOnce(u16, u16, CertificateDer<'static>) -> Pin>>, { let gateway_port = find_free_port(); + let gateway = Uri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap(); let relay_port = find_free_port(); let (key, cert) = gen_localhost_cert(); let cert_clone = cert.clone(); @@ -277,7 +282,7 @@ mod integration { _ = example_gateway_https(gateway_port, (key, cert)) => { assert!(false, "Gateway is long running"); } - _ = listen_tcp(relay_port, format!("http://localhost:{}", gateway_port)) => { + _ = listen_tcp(relay_port, gateway) => { assert!(false, "Relay is long running"); } _ = client_fn(relay_port, gateway_port, cert_clone) => {} @@ -327,7 +332,7 @@ mod integration { } fn gen_localhost_cert() -> (PrivateKeyDer<'static>, CertificateDer<'static>) { - let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap(); + let cert = rcgen::generate_simple_self_signed(vec!["0.0.0.0".to_string()]).unwrap(); let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.serialize_private_key_der())); let cert = CertificateDer::from(cert.serialize_der().unwrap());