Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add listen_tcp_on_free_port to return a test port #41

Merged
merged 3 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ default = ["bootstrap"]
bootstrap = ["connect-bootstrap", "ws-bootstrap"]
connect-bootstrap = []
ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"]
_test-util = []

[dependencies]
futures = { version = "0.3", optional = true }
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use hyper::{Response, StatusCode};

use crate::{empty, full};

pub(crate) type BoxError = Box<dyn std::error::Error + Send + Sync>;

#[derive(Debug)]
#[allow(clippy::enum_variant_names)]
pub(crate) enum Error {
Expand Down
3 changes: 2 additions & 1 deletion src/gateway_uri.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use http::Uri;

use crate::error::BoxError;
/// A normalized gateway origin URI with a default port if none is specified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GatewayUri(Uri);

impl GatewayUri {
pub fn new(mut gateway_origin: Uri) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
pub fn new(mut gateway_origin: Uri) -> Result<Self, BoxError> {
let (scheme, default_port) = match gateway_origin.scheme_str() {
Some("http") => ("http", 80),
Some("https") | None => ("https", 443),
Expand Down
58 changes: 36 additions & 22 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use tracing::{debug, error, info, instrument};

pub mod error;
mod gateway_uri;
use crate::error::Error;
use crate::error::{BoxError, Error};

#[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))]
pub mod bootstrap;
Expand All @@ -40,7 +40,7 @@ pub static EXPECTED_MEDIA_TYPE: Lazy<HeaderValue> =
pub async fn listen_tcp(
port: u16,
gateway_origin: Uri,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
let addr = SocketAddr::from(([0, 0, 0, 0], port));
let listener = TcpListener::bind(addr).await?;
println!("OHTTP relay listening on tcp://{}", addr);
Expand All @@ -51,42 +51,56 @@ pub async fn listen_tcp(
pub async fn listen_socket(
socket_path: &str,
gateway_origin: Uri,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
let listener = UnixListener::bind(socket_path)?;
info!("OHTTP relay listening on socket: {}", socket_path);
ohttp_relay(listener, gateway_origin).await
}

#[cfg(feature = "_test-util")]
pub async fn listen_tcp_on_free_port(
gateway_origin: Uri,
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
let listener = tokio::net::TcpListener::bind("[::]:0").await?;
let port = listener.local_addr()?.port();
println!("Directory server binding to port {}", listener.local_addr()?);
let handle = ohttp_relay(listener, gateway_origin).await?;
Ok((port, handle))
}

#[instrument(skip(listener))]
async fn ohttp_relay<L>(
mut listener: L,
gateway_origin: Uri,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError>
where
L: Listener + Unpin,
L: Listener + Unpin + Send + 'static,
L::Io: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let gateway_origin = GatewayUri::new(gateway_origin)?;
let gateway_origin: Arc<GatewayUri> = Arc::new(gateway_origin);

while let Ok((stream, _)) = listener.accept().await {
let gateway_origin = gateway_origin.clone();
let io = TokioIo::new(stream);
tokio::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}
let handle = tokio::spawn(async move {
while let Ok((stream, _)) = listener.accept().await {
let gateway_origin = gateway_origin.clone();
let io = TokioIo::new(stream);
tokio::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| serve_ohttp_relay(req, gateway_origin.clone())),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}
Ok(())
});

Ok(())
Ok(handle)
}

#[instrument]
Expand Down
3 changes: 1 addition & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
}
(Err(_), Err(_)) => ohttp_relay::listen_tcp(DEFAULT_PORT, gateway_origin).await?,
}

Ok(())
.await?
}

fn init_tracing() {
Expand Down
39 changes: 34 additions & 5 deletions tests/integration.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#[cfg(test)]
#[cfg(feature = "_test-util")]
mod integration {
use std::fs::File;
use std::io::Read;
Expand Down Expand Up @@ -34,7 +35,13 @@ mod integration {
async fn test_request_response_tcp() {
let gateway_port = find_free_port();
let gateway = Uri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap();
let relay_port = find_free_port();
let (relay_port, relay_handle) =
listen_tcp_on_free_port(gateway).await.expect("Failed to listen on free port");
let relay_task = tokio::spawn(async move {
if let Err(e) = relay_handle.await {
eprintln!("Relay failed: {}", e);
}
});
let n_http_port = find_free_port();
let n_https_port = find_free_port();
let nginx_cert = gen_localhost_cert();
Expand All @@ -46,7 +53,7 @@ mod integration {
_ = example_gateway_http(gateway_port) => {
assert!(false, "Gateway is long running");
}
_ = listen_tcp(relay_port, gateway) => {
_ = relay_task => {
assert!(false, "Relay is long running");
}
_ = ohttp_req(n_https_port, nginx_cert_der) => {}
Expand All @@ -67,6 +74,13 @@ mod integration {
let nginx_cert = gen_localhost_cert();
let nginx_cert_der = cert_to_cert_der(&nginx_cert);
let socket_path_str = socket_path.to_str().unwrap();
let relay_handle =
listen_socket(socket_path_str, gateway).await.expect("Failed to listen on socket");
let relay_task = tokio::spawn(async move {
if let Err(e) = relay_handle.await {
eprintln!("Relay failed: {}", e);
}
});
let n_http_port = find_free_port();
let n_https_port = find_free_port();
let _nginx =
Expand All @@ -76,7 +90,7 @@ mod integration {
_ = example_gateway_http(gateway_port) => {
assert!(false, "Gateway is long running");
}
_ = listen_socket(socket_path_str, gateway) => {
_ = relay_task => {
assert!(false, "Relay is long running");
}
_ = ohttp_req(n_https_port, nginx_cert_der) => {}
Expand Down Expand Up @@ -286,10 +300,16 @@ mod integration {
{
let gateway_port = find_free_port();
let gateway = Uri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap();
let relay_port = find_free_port();
let nginx_cert = gen_localhost_cert();
let gateway_cert = gen_localhost_cert();
let gateway_cert_der = cert_to_cert_der(&gateway_cert);
let (relay_port, relay_handle) =
listen_tcp_on_free_port(gateway).await.expect("Failed to listen on free port");
let relay_task = tokio::spawn(async move {
if let Err(e) = relay_handle.await {
eprintln!("Relay failed: {}", e);
}
});
let n_http_port = find_free_port();
let n_https_port = find_free_port();
let _nginx = start_nginx(
Expand All @@ -303,7 +323,7 @@ mod integration {
_ = example_gateway_https(gateway_port, gateway_cert) => {
assert!(false, "Gateway is long running");
}
_ = listen_tcp(relay_port, gateway) => {
_ = relay_task => {
assert!(false, "Relay is long running");
}
_ = client_fn(n_http_port, gateway_port, gateway_cert_der) => {}
Expand Down Expand Up @@ -434,6 +454,15 @@ mod integration {
.spawn()
.expect("Failed to start nginx");

let start = std::time::Instant::now();
let timeout = std::time::Duration::from_secs(5);
while start.elapsed() < timeout {
if let Ok(_) = std::net::TcpStream::connect(format!("127.0.0.1:{}", n_https_port)) {
break;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
Comment on lines +459 to +464

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

happy path and error path are not distinguished, which is strange to me... if the timeout elapses and nginx was not yet listening, why proceed?


// Keep the config file open as long as NGINX is using it
std::mem::forget(config_file);

Expand Down
Loading