Skip to content

Commit

Permalink
Add listen_tcp_on_free_port to return a test port
Browse files Browse the repository at this point in the history
Previously in tests downstream ohttp_relay was initiated with a port
that may no longer be free by the time it got bound. By having this
code bind on and return the port the indirection is removed.
  • Loading branch information
DanGould committed Dec 30, 2024
1 parent 59e8a00 commit 9cbec68
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 28 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ default = ["bootstrap"]
bootstrap = ["connect-bootstrap", "ws-bootstrap"]
connect-bootstrap = []
ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"]
_test-util = []

[dependencies]
futures = { version = "0.3", optional = true }
Expand Down
65 changes: 44 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,70 @@ pub static EXPECTED_MEDIA_TYPE: Lazy<HeaderValue> =
Lazy::new(|| HeaderValue::from_str("message/ohttp-req").expect("Invalid HeaderValue"));

#[instrument]
pub async fn listen_tcp(port: u16, gateway_origin: Uri) -> Result<(), BoxError> {
pub async fn listen_tcp(
port: u16,
gateway_origin: Uri,
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
let addr = SocketAddr::from(([0, 0, 0, 0], port));
let listener = TcpListener::bind(addr).await?;
println!("OHTTP relay listening on tcp://{}", addr);
ohttp_relay(listener, gateway_origin).await
}

#[instrument]
pub async fn listen_socket(socket_path: &str, gateway_origin: Uri) -> Result<(), BoxError> {
pub async fn listen_socket(
socket_path: &str,
gateway_origin: Uri,
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
let listener = UnixListener::bind(socket_path)?;
info!("OHTTP relay listening on socket: {}", socket_path);
ohttp_relay(listener, gateway_origin).await
}

#[cfg(feature = "_test-util")]
pub async fn listen_tcp_on_free_port(
gateway_origin: Uri,
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
let listener = tokio::net::TcpListener::bind("[::]:0").await?;
let port = listener.local_addr()?.port();
println!("Directory server binding to port {}", listener.local_addr()?);
let handle = ohttp_relay(listener, gateway_origin).await?;
Ok((port, handle))
}

#[instrument(skip(listener))]
async fn ohttp_relay<L>(mut listener: L, gateway_origin: Uri) -> Result<(), BoxError>
async fn ohttp_relay<L>(
mut listener: L,
gateway_origin: Uri,
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError>
where
L: Listener + Unpin,
L: Listener + Unpin + Send + 'static,
L::Io: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let gateway_origin = GatewayUri::new(gateway_origin)?;
let gateway_origin: Arc<GatewayUri> = Arc::new(gateway_origin);

while let Ok((stream, _)) = listener.accept().await {
let gateway_origin = gateway_origin.clone();
let io = TokioIo::new(stream);
tokio::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}
let handle = tokio::spawn(async move {
while let Ok((stream, _)) = listener.accept().await {
let gateway_origin = gateway_origin.clone();
let io = TokioIo::new(stream);
tokio::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}
Ok(())
});

Ok(())
Ok(handle)
}

#[instrument]
Expand Down
3 changes: 1 addition & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
}
(Err(_), Err(_)) => ohttp_relay::listen_tcp(DEFAULT_PORT, gateway_origin).await?,
}

Ok(())
.await?
}

fn init_tracing() {
Expand Down
30 changes: 25 additions & 5 deletions tests/integration.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#[cfg(test)]
#[cfg(feature = "_test-util")]
mod integration {
use std::fs::File;
use std::io::Read;
Expand Down Expand Up @@ -34,7 +35,13 @@ mod integration {
async fn test_request_response_tcp() {
let gateway_port = find_free_port();
let gateway = Uri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap();
let relay_port = find_free_port();
let (relay_port, relay_handle) =
listen_tcp_on_free_port(gateway).await.expect("Failed to listen on free port");
let relay_task = tokio::spawn(async move {
if let Err(e) = relay_handle.await {
eprintln!("Relay failed: {}", e);
}
});
let n_http_port = find_free_port();
let n_https_port = find_free_port();
let nginx_cert = gen_localhost_cert();
Expand All @@ -46,7 +53,7 @@ mod integration {
_ = example_gateway_http(gateway_port) => {
assert!(false, "Gateway is long running");
}
_ = listen_tcp(relay_port, gateway) => {
_ = relay_task => {
assert!(false, "Relay is long running");
}
_ = ohttp_req(n_https_port, nginx_cert_der) => {}
Expand All @@ -67,6 +74,13 @@ mod integration {
let nginx_cert = gen_localhost_cert();
let nginx_cert_der = cert_to_cert_der(&nginx_cert);
let socket_path_str = socket_path.to_str().unwrap();
let relay_handle =
listen_socket(socket_path_str, gateway).await.expect("Failed to listen on socket");
let relay_task = tokio::spawn(async move {
if let Err(e) = relay_handle.await {
eprintln!("Relay failed: {}", e);
}
});
let n_http_port = find_free_port();
let n_https_port = find_free_port();
let _nginx =
Expand All @@ -76,7 +90,7 @@ mod integration {
_ = example_gateway_http(gateway_port) => {
assert!(false, "Gateway is long running");
}
_ = listen_socket(socket_path_str, gateway) => {
_ = relay_task => {
assert!(false, "Relay is long running");
}
_ = ohttp_req(n_https_port, nginx_cert_der) => {}
Expand Down Expand Up @@ -286,10 +300,16 @@ mod integration {
{
let gateway_port = find_free_port();
let gateway = Uri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap();
let relay_port = find_free_port();
let nginx_cert = gen_localhost_cert();
let gateway_cert = gen_localhost_cert();
let gateway_cert_der = cert_to_cert_der(&gateway_cert);
let (relay_port, relay_handle) =
listen_tcp_on_free_port(gateway).await.expect("Failed to listen on free port");
let relay_task = tokio::spawn(async move {
if let Err(e) = relay_handle.await {
eprintln!("Relay failed: {}", e);
}
});
let n_http_port = find_free_port();
let n_https_port = find_free_port();
let _nginx = start_nginx(
Expand All @@ -303,7 +323,7 @@ mod integration {
_ = example_gateway_https(gateway_port, gateway_cert) => {
assert!(false, "Gateway is long running");
}
_ = listen_tcp(relay_port, gateway) => {
_ = relay_task => {
assert!(false, "Relay is long running");
}
_ = client_fn(n_http_port, gateway_port, gateway_cert_der) => {}
Expand Down

0 comments on commit 9cbec68

Please sign in to comment.