Skip to content

Commit

Permalink
Try parse gateway_origin as Uri
Browse files Browse the repository at this point in the history
... in order to parse SocketAddr out of it more reliably.

0.0.0.0 replaces localhost to translate between Uri and SocketAddr.
  • Loading branch information
DanGould committed Feb 20, 2024
1 parent a28c9bd commit 21140a4
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 48 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"]

[dependencies]
futures = { version = "0.3", optional = true }
http = "1"
http-body-util = "0.1"
hyper = { version = "1", features = ["http1", "server"] }
hyper-rustls = { version = "0.26", features = ["webpki-roots"] }
Expand Down
32 changes: 16 additions & 16 deletions src/bootstrap/connect.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
use std::net::SocketAddr;
use std::sync::Arc;

use http::Uri;
use http_body_util::combinators::BoxBody;
use hyper::body::{Bytes, Incoming};
use hyper::upgrade::Upgraded;
use hyper::{Method, Request, Response};
use hyper_util::rt::TokioIo;
use tokio::net::TcpStream;

use crate::empty;
use crate::error::Error;
use crate::{empty, uri_to_addr};

pub(crate) fn is_connect_request(req: &Request<Incoming>) -> bool {
Method::CONNECT == req.method()
}

pub(crate) async fn try_upgrade(
req: Request<Incoming>,
gateway_origin: Arc<String>,
gateway_origin: Arc<Uri>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
if let Some(addr) = find_allowable_gateway(&req, &gateway_origin) {
tokio::task::spawn(async move {
Expand All @@ -38,7 +40,7 @@ pub(crate) async fn try_upgrade(

/// Create a TCP connection to host:port, build a tunnel between the connection and
/// the upgraded connection
async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> {
async fn tunnel(upgraded: Upgraded, addr: SocketAddr) -> std::io::Result<()> {
let mut server = TcpStream::connect(addr).await?;
let mut upgraded = TokioIo::new(upgraded);
let (_, _) = tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?;
Expand All @@ -48,36 +50,34 @@ async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> {
/// Only allow CONNECT requests to the configured OHTTP gateway authority.
/// This prevents the relay from being used as an arbitrary proxy
/// to any host on the internet.
fn find_allowable_gateway<B>(req: &Request<B>, gateway_origin: &str) -> Option<String> {
let gateway_authority =
gateway_origin.trim_start_matches("https://").trim_start_matches("http://");
let req_gateway = req.uri().authority().map(|auth| auth.to_string());
if req_gateway == Some(gateway_authority.to_string()) {
req_gateway
} else {
None
fn find_allowable_gateway<B>(req: &Request<B>, gateway_origin: &Uri) -> Option<SocketAddr> {
if req.uri().authority() != gateway_origin.authority() {
return None;
}

uri_to_addr(gateway_origin)
}

#[cfg(test)]
mod test {
use hyper::Request;
use once_cell::sync::Lazy;

use super::*;

static GATEWAY_ORIGIN: Lazy<Uri> = Lazy::new(|| Uri::from_static("https://gateway.com"));

#[test]
fn mismatched_gateways_not_allowed() {
let gateway_origin = "https://gateway.com";
let not_gateway_origin = "https://not-gateway.com";
let req = hyper::Request::builder().uri(not_gateway_origin).body(()).unwrap();
let allowable_gateway = find_allowable_gateway(&req, gateway_origin);
let allowable_gateway = find_allowable_gateway(&req, &*GATEWAY_ORIGIN);
assert!(allowable_gateway.is_none());
}

#[test]
fn matched_gateways_allowed() {
let gateway_origin = "https://gateway.com";
let req = Request::builder().uri(gateway_origin).body(()).unwrap();
assert!(find_allowable_gateway(&req, gateway_origin).is_some());
let req = Request::builder().uri(&*GATEWAY_ORIGIN).body(()).unwrap();
assert!(find_allowable_gateway(&req, &*GATEWAY_ORIGIN).is_some());
}
}
3 changes: 2 additions & 1 deletion src/bootstrap/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::sync::Arc;

use http::Uri;
use http_body_util::combinators::BoxBody;
use hyper::body::{Bytes, Incoming};
use hyper::{Request, Response};
Expand All @@ -14,7 +15,7 @@ pub mod ws;

pub(crate) async fn handle_ohttp_keys(
mut req: Request<Incoming>,
gateway_origin: Arc<String>,
gateway_origin: Arc<Uri>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
#[cfg(feature = "connect-bootstrap")]
if connect::is_connect_request(&req) {
Expand Down
13 changes: 8 additions & 5 deletions src/bootstrap/ws.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures::{Sink, SinkExt, StreamExt};
use http::Uri;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::body::{Bytes, Incoming};
Expand All @@ -14,19 +16,21 @@ use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_tungstenite::{tungstenite, WebSocketStream};

use crate::error::Error;
use crate::uri_to_addr;

pub(crate) fn is_websocket_request(req: &Request<Incoming>) -> bool {
hyper_tungstenite::is_upgrade_request(req)
}

pub(crate) async fn try_upgrade(
req: &mut Request<Incoming>,
gateway_origin: Arc<String>,
gateway_origin: Arc<Uri>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
let (res, websocket) = hyper_tungstenite::upgrade(req, None)
.map_err(|e| Error::BadRequest(format!("Error upgrading to websocket: {}", e)))?;
let gateway_addr = uri_to_addr(&gateway_origin).ok_or(Error::InternalServerError)?;
tokio::spawn(async move {
if let Err(e) = serve_websocket(websocket, gateway_origin.as_str()).await {
if let Err(e) = serve_websocket(websocket, gateway_addr).await {
eprintln!("Error in websocket connection: {e}");
}
});
Expand All @@ -38,10 +42,9 @@ pub(crate) async fn try_upgrade(
/// Stream WebSocket frames from the client to the gateway server's TCP socket and vice versa.
async fn serve_websocket(
websocket: HyperWebsocket,
gateway_origin: &str,
gateway_addr: SocketAddr,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
let addr = gateway_origin.trim_start_matches("https://").trim_start_matches("http://");
let mut tcp_stream = tokio::net::TcpStream::connect(addr).await?;
let mut tcp_stream = tokio::net::TcpStream::connect(gateway_addr).await?;
let mut ws_io = WsIo::new(websocket.await?);
let (_, _) = tokio::io::copy_bidirectional(&mut ws_io, &mut tcp_stream).await?;
Ok(())
Expand Down
4 changes: 4 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub(crate) enum Error {
UnsupportedMediaType,
BadRequest(String),
NotFound,
#[allow(clippy::enum_variant_names)]
InternalServerError,
}

impl Error {
Expand All @@ -26,6 +28,7 @@ impl Error {
*res.body_mut() = full(e.to_string()).boxed();
}
Self::NotFound => *res.status_mut() = StatusCode::NOT_FOUND,
Self::InternalServerError => *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR,
};

res
Expand All @@ -40,6 +43,7 @@ impl std::fmt::Display for Error {
Self::MethodNotAllowed => write!(f, "Method not allowed"),
Self::BadRequest(e) => write!(f, "Bad request: {}", e),
Self::NotFound => write!(f, "Not found"),
Self::InternalServerError => write!(f, "Internal server error"),
}
}
}
Expand Down
53 changes: 37 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::net::SocketAddr;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;

use http::uri::PathAndQuery;
use http::Uri;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty, Full};
use hyper::body::{Bytes, Incoming};
Expand All @@ -24,13 +26,13 @@ pub mod bootstrap;

pub const DEFAULT_PORT: u16 = 3000;
pub static OHTTP_RELAY_HOST: Lazy<HeaderValue> =
Lazy::new(|| HeaderValue::from_str("localhost").expect("Invalid HeaderValue"));
Lazy::new(|| HeaderValue::from_str("0.0.0.0").expect("Invalid HeaderValue"));
pub static EXPECTED_MEDIA_TYPE: Lazy<HeaderValue> =
Lazy::new(|| HeaderValue::from_str("message/ohttp-req").expect("Invalid HeaderValue"));

pub async fn listen_tcp(
port: u16,
gateway_origin: String,
gateway_origin: Uri,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let listener = TcpListener::bind(addr).await?;
Expand All @@ -40,7 +42,7 @@ pub async fn listen_tcp(

pub async fn listen_socket(
socket_path: &str,
gateway_origin: String,
gateway_origin: Uri,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let listener = UnixListener::bind(socket_path)?;
println!("OHTTP relay listening on socket: {}", socket_path);
Expand All @@ -49,7 +51,7 @@ pub async fn listen_socket(

async fn ohttp_relay<L>(
mut listener: L,
gateway_origin: String,
gateway_origin: Uri,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
L: Listener + Unpin,
Expand Down Expand Up @@ -79,11 +81,11 @@ where

async fn serve_ohttp_relay(
req: Request<Incoming>,
gateway_origin: Arc<String>,
gateway_origin: Arc<Uri>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
println!("req: {:?}", req);
let res = match req.method() {
&Method::POST => handle_ohttp_relay(req, gateway_origin.as_str()).await,
&Method::POST => handle_ohttp_relay(req, &gateway_origin).await,
#[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))]
&Method::CONNECT | &Method::GET =>
crate::bootstrap::handle_ohttp_keys(req, gateway_origin).await,
Expand All @@ -95,7 +97,7 @@ async fn serve_ohttp_relay(

async fn handle_ohttp_relay(
req: Request<Incoming>,
gateway_origin: &str,
gateway_origin: &Uri,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
let fwd_req = into_forward_req(req, gateway_origin)?;
forward_request(fwd_req).await.map(|res| {
Expand All @@ -108,7 +110,7 @@ async fn handle_ohttp_relay(
/// Convert an incoming request into a request to forward to the target gateway server.
fn into_forward_req(
mut req: Request<Incoming>,
gateway_origin: &str,
gateway_origin: &Uri,
) -> Result<Request<Incoming>, Error> {
if req.method() != hyper::Method::POST {
return Err(Error::MethodNotAllowed);
Expand All @@ -124,13 +126,17 @@ fn into_forward_req(
req.headers_mut().insert(CONTENT_LENGTH, content_length);
}

let uri_string = format!(
"{}{}",
gateway_origin,
req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("/")
);
let uri = uri_string.parse().map_err(|_| Error::BadRequest("Invalid target uri".to_owned()))?;
*req.uri_mut() = uri;
let req_path_and_query =
req.uri().path_and_query().map_or_else(|| PathAndQuery::from_static("/"), |pq| pq.clone());

*req.uri_mut() = Uri::builder()
.scheme(gateway_origin.scheme_str().unwrap_or("https"))
.authority(
gateway_origin.authority().expect("Gateway origin must have an authority").as_str(),
)
.path_and_query(req_path_and_query.as_str())
.build()
.map_err(|_| Error::BadRequest("Invalid target uri".to_owned()))?;
Ok(req)
}

Expand All @@ -141,6 +147,21 @@ async fn forward_request(req: Request<Incoming>) -> Result<Response<Incoming>, E
client.request(req).await.map_err(|_| Error::BadGateway)
}

pub(crate) fn uri_to_addr(uri: &Uri) -> Option<SocketAddr> {
let authority = uri.authority()?.as_str();
let parts: Vec<&str> = authority.split(':').collect();
let host = parts.first()?;
let port = parts.get(1).and_then(|p| p.parse::<u16>().ok());

let default_port = match uri.scheme_str() {
Some("https") => 443,
_ => 80, // Default to 80 if it's not https or if the scheme is not specified
};

let addr_str = format!("{}:{}", host, port.unwrap_or(default_port));
addr_str.to_socket_addrs().ok()?.next()
}

pub(crate) fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new().map_err(|never| match never {}).boxed()
}
Expand Down
6 changes: 5 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use std::str::FromStr;

use http::Uri;
use ohttp_relay::DEFAULT_PORT;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let port_env = std::env::var("PORT");
let unix_socket_env = std::env::var("UNIX_SOCKET");
let gateway_origin = std::env::var("GATEWAY_ORIGIN").expect("GATEWAY_ORIGIN is required");
let gateway_origin_str = std::env::var("GATEWAY_ORIGIN").expect("GATEWAY_ORIGIN is required");
let gateway_origin = Uri::from_str(&gateway_origin_str).expect("Invalid GATEWAY_ORIGIN URI");

match (port_env, unix_socket_env) {
(Ok(_), Ok(_)) => panic!(
Expand Down
Loading

0 comments on commit 21140a4

Please sign in to comment.