Skip to content

Commit

Permalink
Merge pull request #19 from DanGould/uri
Browse files Browse the repository at this point in the history
Try parse gateway_origin as Uri
  • Loading branch information
DanGould authored Feb 20, 2024
2 parents a28c9bd + 21140a4 commit 44d6bee
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 48 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"]

[dependencies]
futures = { version = "0.3", optional = true }
http = "1"
http-body-util = "0.1"
hyper = { version = "1", features = ["http1", "server"] }
hyper-rustls = { version = "0.26", features = ["webpki-roots"] }
Expand Down
32 changes: 16 additions & 16 deletions src/bootstrap/connect.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
use std::net::SocketAddr;
use std::sync::Arc;

use http::Uri;
use http_body_util::combinators::BoxBody;
use hyper::body::{Bytes, Incoming};
use hyper::upgrade::Upgraded;
use hyper::{Method, Request, Response};
use hyper_util::rt::TokioIo;
use tokio::net::TcpStream;

use crate::empty;
use crate::error::Error;
use crate::{empty, uri_to_addr};

pub(crate) fn is_connect_request(req: &Request<Incoming>) -> bool {
Method::CONNECT == req.method()
}

pub(crate) async fn try_upgrade(
req: Request<Incoming>,
gateway_origin: Arc<String>,
gateway_origin: Arc<Uri>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
if let Some(addr) = find_allowable_gateway(&req, &gateway_origin) {
tokio::task::spawn(async move {
Expand All @@ -38,7 +40,7 @@ pub(crate) async fn try_upgrade(

/// Create a TCP connection to host:port, build a tunnel between the connection and
/// the upgraded connection
async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> {
async fn tunnel(upgraded: Upgraded, addr: SocketAddr) -> std::io::Result<()> {
let mut server = TcpStream::connect(addr).await?;
let mut upgraded = TokioIo::new(upgraded);
let (_, _) = tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?;
Expand All @@ -48,36 +50,34 @@ async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> {
/// Only allow CONNECT requests to the configured OHTTP gateway authority.
/// This prevents the relay from being used as an arbitrary proxy
/// to any host on the internet.
fn find_allowable_gateway<B>(req: &Request<B>, gateway_origin: &str) -> Option<String> {
let gateway_authority =
gateway_origin.trim_start_matches("https://").trim_start_matches("http://");
let req_gateway = req.uri().authority().map(|auth| auth.to_string());
if req_gateway == Some(gateway_authority.to_string()) {
req_gateway
} else {
None
fn find_allowable_gateway<B>(req: &Request<B>, gateway_origin: &Uri) -> Option<SocketAddr> {
if req.uri().authority() != gateway_origin.authority() {
return None;
}

uri_to_addr(gateway_origin)
}

#[cfg(test)]
mod test {
use hyper::Request;
use once_cell::sync::Lazy;

use super::*;

static GATEWAY_ORIGIN: Lazy<Uri> = Lazy::new(|| Uri::from_static("https://gateway.com"));

#[test]
fn mismatched_gateways_not_allowed() {
let gateway_origin = "https://gateway.com";
let not_gateway_origin = "https://not-gateway.com";
let req = hyper::Request::builder().uri(not_gateway_origin).body(()).unwrap();
let allowable_gateway = find_allowable_gateway(&req, gateway_origin);
let allowable_gateway = find_allowable_gateway(&req, &*GATEWAY_ORIGIN);
assert!(allowable_gateway.is_none());
}

#[test]
fn matched_gateways_allowed() {
let gateway_origin = "https://gateway.com";
let req = Request::builder().uri(gateway_origin).body(()).unwrap();
assert!(find_allowable_gateway(&req, gateway_origin).is_some());
let req = Request::builder().uri(&*GATEWAY_ORIGIN).body(()).unwrap();
assert!(find_allowable_gateway(&req, &*GATEWAY_ORIGIN).is_some());
}
}
3 changes: 2 additions & 1 deletion src/bootstrap/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::sync::Arc;

use http::Uri;
use http_body_util::combinators::BoxBody;
use hyper::body::{Bytes, Incoming};
use hyper::{Request, Response};
Expand All @@ -14,7 +15,7 @@ pub mod ws;

pub(crate) async fn handle_ohttp_keys(
mut req: Request<Incoming>,
gateway_origin: Arc<String>,
gateway_origin: Arc<Uri>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
#[cfg(feature = "connect-bootstrap")]
if connect::is_connect_request(&req) {
Expand Down
13 changes: 8 additions & 5 deletions src/bootstrap/ws.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures::{Sink, SinkExt, StreamExt};
use http::Uri;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::body::{Bytes, Incoming};
Expand All @@ -14,19 +16,21 @@ use tokio_tungstenite::tungstenite::protocol::Message;
use tokio_tungstenite::{tungstenite, WebSocketStream};

use crate::error::Error;
use crate::uri_to_addr;

pub(crate) fn is_websocket_request(req: &Request<Incoming>) -> bool {
hyper_tungstenite::is_upgrade_request(req)
}

pub(crate) async fn try_upgrade(
req: &mut Request<Incoming>,
gateway_origin: Arc<String>,
gateway_origin: Arc<Uri>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
let (res, websocket) = hyper_tungstenite::upgrade(req, None)
.map_err(|e| Error::BadRequest(format!("Error upgrading to websocket: {}", e)))?;
let gateway_addr = uri_to_addr(&gateway_origin).ok_or(Error::InternalServerError)?;
tokio::spawn(async move {
if let Err(e) = serve_websocket(websocket, gateway_origin.as_str()).await {
if let Err(e) = serve_websocket(websocket, gateway_addr).await {
eprintln!("Error in websocket connection: {e}");
}
});
Expand All @@ -38,10 +42,9 @@ pub(crate) async fn try_upgrade(
/// Stream WebSocket frames from the client to the gateway server's TCP socket and vice versa.
async fn serve_websocket(
websocket: HyperWebsocket,
gateway_origin: &str,
gateway_addr: SocketAddr,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
let addr = gateway_origin.trim_start_matches("https://").trim_start_matches("http://");
let mut tcp_stream = tokio::net::TcpStream::connect(addr).await?;
let mut tcp_stream = tokio::net::TcpStream::connect(gateway_addr).await?;
let mut ws_io = WsIo::new(websocket.await?);
let (_, _) = tokio::io::copy_bidirectional(&mut ws_io, &mut tcp_stream).await?;
Ok(())
Expand Down
4 changes: 4 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub(crate) enum Error {
UnsupportedMediaType,
BadRequest(String),
NotFound,
#[allow(clippy::enum_variant_names)]
InternalServerError,
}

impl Error {
Expand All @@ -26,6 +28,7 @@ impl Error {
*res.body_mut() = full(e.to_string()).boxed();
}
Self::NotFound => *res.status_mut() = StatusCode::NOT_FOUND,
Self::InternalServerError => *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR,
};

res
Expand All @@ -40,6 +43,7 @@ impl std::fmt::Display for Error {
Self::MethodNotAllowed => write!(f, "Method not allowed"),
Self::BadRequest(e) => write!(f, "Bad request: {}", e),
Self::NotFound => write!(f, "Not found"),
Self::InternalServerError => write!(f, "Internal server error"),
}
}
}
Expand Down
53 changes: 37 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::net::SocketAddr;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;

use http::uri::PathAndQuery;
use http::Uri;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty, Full};
use hyper::body::{Bytes, Incoming};
Expand All @@ -24,13 +26,13 @@ pub mod bootstrap;

pub const DEFAULT_PORT: u16 = 3000;
pub static OHTTP_RELAY_HOST: Lazy<HeaderValue> =
Lazy::new(|| HeaderValue::from_str("localhost").expect("Invalid HeaderValue"));
Lazy::new(|| HeaderValue::from_str("0.0.0.0").expect("Invalid HeaderValue"));
pub static EXPECTED_MEDIA_TYPE: Lazy<HeaderValue> =
Lazy::new(|| HeaderValue::from_str("message/ohttp-req").expect("Invalid HeaderValue"));

pub async fn listen_tcp(
port: u16,
gateway_origin: String,
gateway_origin: Uri,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let listener = TcpListener::bind(addr).await?;
Expand All @@ -40,7 +42,7 @@ pub async fn listen_tcp(

pub async fn listen_socket(
socket_path: &str,
gateway_origin: String,
gateway_origin: Uri,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let listener = UnixListener::bind(socket_path)?;
println!("OHTTP relay listening on socket: {}", socket_path);
Expand All @@ -49,7 +51,7 @@ pub async fn listen_socket(

async fn ohttp_relay<L>(
mut listener: L,
gateway_origin: String,
gateway_origin: Uri,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
L: Listener + Unpin,
Expand Down Expand Up @@ -79,11 +81,11 @@ where

async fn serve_ohttp_relay(
req: Request<Incoming>,
gateway_origin: Arc<String>,
gateway_origin: Arc<Uri>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
println!("req: {:?}", req);
let res = match req.method() {
&Method::POST => handle_ohttp_relay(req, gateway_origin.as_str()).await,
&Method::POST => handle_ohttp_relay(req, &gateway_origin).await,
#[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))]
&Method::CONNECT | &Method::GET =>
crate::bootstrap::handle_ohttp_keys(req, gateway_origin).await,
Expand All @@ -95,7 +97,7 @@ async fn serve_ohttp_relay(

async fn handle_ohttp_relay(
req: Request<Incoming>,
gateway_origin: &str,
gateway_origin: &Uri,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, Error> {
let fwd_req = into_forward_req(req, gateway_origin)?;
forward_request(fwd_req).await.map(|res| {
Expand All @@ -108,7 +110,7 @@ async fn handle_ohttp_relay(
/// Convert an incoming request into a request to forward to the target gateway server.
fn into_forward_req(
mut req: Request<Incoming>,
gateway_origin: &str,
gateway_origin: &Uri,
) -> Result<Request<Incoming>, Error> {
if req.method() != hyper::Method::POST {
return Err(Error::MethodNotAllowed);
Expand All @@ -124,13 +126,17 @@ fn into_forward_req(
req.headers_mut().insert(CONTENT_LENGTH, content_length);
}

let uri_string = format!(
"{}{}",
gateway_origin,
req.uri().path_and_query().map(|x| x.as_str()).unwrap_or("/")
);
let uri = uri_string.parse().map_err(|_| Error::BadRequest("Invalid target uri".to_owned()))?;
*req.uri_mut() = uri;
let req_path_and_query =
req.uri().path_and_query().map_or_else(|| PathAndQuery::from_static("/"), |pq| pq.clone());

*req.uri_mut() = Uri::builder()
.scheme(gateway_origin.scheme_str().unwrap_or("https"))
.authority(
gateway_origin.authority().expect("Gateway origin must have an authority").as_str(),
)
.path_and_query(req_path_and_query.as_str())
.build()
.map_err(|_| Error::BadRequest("Invalid target uri".to_owned()))?;
Ok(req)
}

Expand All @@ -141,6 +147,21 @@ async fn forward_request(req: Request<Incoming>) -> Result<Response<Incoming>, E
client.request(req).await.map_err(|_| Error::BadGateway)
}

pub(crate) fn uri_to_addr(uri: &Uri) -> Option<SocketAddr> {
let authority = uri.authority()?.as_str();
let parts: Vec<&str> = authority.split(':').collect();
let host = parts.first()?;
let port = parts.get(1).and_then(|p| p.parse::<u16>().ok());

let default_port = match uri.scheme_str() {
Some("https") => 443,
_ => 80, // Default to 80 if it's not https or if the scheme is not specified
};

let addr_str = format!("{}:{}", host, port.unwrap_or(default_port));
addr_str.to_socket_addrs().ok()?.next()
}

pub(crate) fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new().map_err(|never| match never {}).boxed()
}
Expand Down
6 changes: 5 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use std::str::FromStr;

use http::Uri;
use ohttp_relay::DEFAULT_PORT;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let port_env = std::env::var("PORT");
let unix_socket_env = std::env::var("UNIX_SOCKET");
let gateway_origin = std::env::var("GATEWAY_ORIGIN").expect("GATEWAY_ORIGIN is required");
let gateway_origin_str = std::env::var("GATEWAY_ORIGIN").expect("GATEWAY_ORIGIN is required");
let gateway_origin = Uri::from_str(&gateway_origin_str).expect("Invalid GATEWAY_ORIGIN URI");

match (port_env, unix_socket_env) {
(Ok(_), Ok(_)) => panic!(
Expand Down
Loading

0 comments on commit 44d6bee

Please sign in to comment.