diff --git a/Cargo.lock b/Cargo.lock index d77d4b9..0c7b421 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -324,6 +324,7 @@ dependencies = [ "reqwest", "rstest", "rusqlite", + "rustix", "scopeguard", "sd-notify", "serde", @@ -333,6 +334,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "sha2", + "tempfile", "tokio", "tokio-stream", "toml", diff --git a/blahd/Cargo.toml b/blahd/Cargo.toml index a113186..6521e62 100644 --- a/blahd/Cargo.toml +++ b/blahd/Cargo.toml @@ -21,6 +21,7 @@ parking_lot = "0.12" # Maybe no better performance, just that we hate poisoning. rand = "0.8" reqwest = "0.12" rusqlite = "0.32" +rustix = { version = "0.38", features = ["net"] } sd-notify = "0.4" serde = { version = "1", features = ["derive"] } serde-constant = "0.1" @@ -45,6 +46,7 @@ nix = { version = "0.29.0", features = ["fs", "process", "signal"] } reqwest = { version = "0.12.7", features = ["json"] } rstest = { version = "0.22.0", default-features = false } scopeguard = "1.2.0" +tempfile = "3.12.0" [lints] workspace = true diff --git a/blahd/config.example.toml b/blahd/config.example.toml index 0d07efa..d22b983 100644 --- a/blahd/config.example.toml +++ b/blahd/config.example.toml @@ -24,6 +24,7 @@ address = "localhost:8080" # Use systemd socket activation mechanism to get listener fd from envvars. # See also sd_listen_fds(3) and systemd.socket(5). +# NB. Currently only TCP sockets are supported. UNIX domain socket is TODO. #systemd = true [server] diff --git a/blahd/src/bin/blahd.rs b/blahd/src/bin/blahd.rs index c40cb42..fe4d06c 100644 --- a/blahd/src/bin/blahd.rs +++ b/blahd/src/bin/blahd.rs @@ -1,9 +1,8 @@ -use std::net::TcpListener; -use std::os::fd::FromRawFd; +use std::os::fd::{FromRawFd, OwnedFd}; use std::path::PathBuf; use std::sync::Arc; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use blahd::config::{Config, ListenConfig}; use blahd::{AppState, Database}; @@ -56,30 +55,44 @@ fn main() -> Result<()> { async fn main_serve(db: Database, config: Config) -> Result<()> { let st = AppState::new(db, config.server); - let listener = match &config.listen { - ListenConfig::Address(addr) => { - tracing::info!("listening on {addr:?}"); + let (listener_display, listener) = match &config.listen { + ListenConfig::Address(addr) => ( + format!("address {addr:?}"), tokio::net::TcpListener::bind(addr) .await - .context("failed to listen on socket")? - } + .context("failed to listen on socket")?, + ), ListenConfig::Systemd(_) => { - tracing::info!("listening on fd from environment"); + use rustix::net::{getsockname, SocketAddrAny}; + let [fd] = sd_notify::listen_fds() .context("failed to get fds from sd_listen_fds(3)")? .collect::>() .try_into() - .map_err(|_| anyhow!("more than one fds available from sd_listen_fds(3)"))?; + .map_err(|_| anyhow!("expecting exactly one fd from LISTEN_FDS"))?; // SAFETY: `fd` is valid by sd_listen_fds(3) protocol. - let listener = unsafe { TcpListener::from_raw_fd(fd) }; - listener - .set_nonblocking(true) - .context("failed to set socket non-blocking")?; - tokio::net::TcpListener::from_std(listener) - .context("failed to register async socket")? + let listener = unsafe { OwnedFd::from_raw_fd(fd) }; + + let addr = getsockname(&listener).context("failed to getsockname")?; + match addr { + SocketAddrAny::V4(_) | SocketAddrAny::V6(_) => { + let listener = std::net::TcpListener::from(listener); + listener + .set_nonblocking(true) + .context("failed to set socket non-blocking")?; + let listener = tokio::net::TcpListener::from_std(listener) + .context("failed to register async socket")?; + (format!("tcp socket {addr:?} from LISTEN_FDS"), listener) + } + // Unix socket support for axum is currently overly complex. + // WAIT: https://github.com/tokio-rs/axum/pull/2479 + _ => bail!("unsupported socket type from LISTEN_FDS: {addr:?}"), + } } }; + tracing::info!("listening on {listener_display}"); + let router = blahd::router(Arc::new(st)); let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); diff --git a/blahd/tests/socket_activate.rs b/blahd/tests/socket_activate.rs index b7f4713..fab971b 100644 --- a/blahd/tests/socket_activate.rs +++ b/blahd/tests/socket_activate.rs @@ -2,7 +2,8 @@ use std::ffi::CString; use std::fs::File; use std::io::{Seek, Write}; use std::net::TcpListener; -use std::os::fd::{AsFd, AsRawFd}; +use std::os::fd::{AsFd, AsRawFd, OwnedFd}; +use std::os::unix::net::UnixListener; use std::process::abort; use std::ptr::null; use std::time::Duration; @@ -13,6 +14,7 @@ use nix::sys::memfd::{memfd_create, MemFdCreateFlag}; use nix::sys::signal::{kill, Signal}; use nix::sys::wait::{waitpid, WaitStatus}; use nix::unistd::{alarm, dup2, fork, getpid, ForkResult}; +use rstest::rstest; use tokio::io::stderr; const TIMEOUT_SEC: u32 = 1; @@ -28,10 +30,22 @@ systemd = true base_url = "http://example.com" "#; -#[test] -fn socket_activate() { - let listener = TcpListener::bind("127.0.0.1:0").unwrap(); - let local_port = listener.local_addr().unwrap().port(); +#[rstest] +#[case::tcp(false)] +#[case::unix(true)] +fn socket_activate(#[case] unix_socket: bool) { + let socket_dir; + let (local_port, listener) = if unix_socket { + socket_dir = tempfile::tempdir().unwrap(); + let listener = UnixListener::bind(socket_dir.path().join("socket")).unwrap(); + // Port is unused. + (0, OwnedFd::from(listener)) + } else { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let local_port = listener.local_addr().unwrap().port(); + (local_port, OwnedFd::from(listener)) + }; + let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build() @@ -89,11 +103,11 @@ fn socket_activate() { } } ForkResult::Parent { child } => { - { - scopeguard::defer! { - let _ = kill(child, Signal::SIGKILL); - } + let guard = scopeguard::guard((), |()| { + let _ = kill(child, Signal::SIGKILL); + }); + if !unix_socket { let resp = rt.block_on(async { let url = format!("http://127.0.0.1:{local_port}/_blah/room?filter=public"); let fut = async { @@ -111,13 +125,20 @@ fn socket_activate() { .unwrap() }); assert_eq!(resp, r#"{"rooms":[]}"#); + // Trigger the killer. + drop(guard); } let st = waitpid(child, None).unwrap(); - assert!( - matches!(st, WaitStatus::Signaled(_, Signal::SIGKILL, _)), - "unexpected exit status {st:?}", - ); + if unix_socket { + // Fail with unsupported error. + assert!(matches!(st, WaitStatus::Exited(_, 1))); + } else { + assert!( + matches!(st, WaitStatus::Signaled(_, Signal::SIGKILL, _)), + "unexpected exit status {st:?}", + ); + } } } }