diff --git a/Cargo.lock b/Cargo.lock index 4c31fdd..f3dbec9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2439,6 +2439,8 @@ dependencies = [ "frost-core", "frost-ed25519", "frost-rerandomized", + "futures", + "futures-util", "hex", "rand", "reddsa", diff --git a/server/Cargo.toml b/server/Cargo.toml index 5ffe17d..496fb67 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -26,6 +26,8 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.11.0", features = ["v4", "fast-rng", "serde"] } xeddsa = "1.0.2" +futures-util = "0.3.31" +futures = "0.3.31" hex = "0.4.3" [dev-dependencies] diff --git a/server/src/functions.rs b/server/src/functions.rs index d42033f..410e986 100644 --- a/server/src/functions.rs +++ b/server/src/functions.rs @@ -106,12 +106,12 @@ pub(crate) async fn create_new_session( // Create new session object. let id = Uuid::new_v4(); - let mut state = state.sessions.write().unwrap(); + let mut sessions = state.sessions.sessions.write().unwrap(); + let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap(); // Save session ID in global state for pubkey in &args.pubkeys { - state - .sessions_by_pubkey + sessions_by_pubkey .entry(pubkey.0.clone()) .or_default() .insert(id); @@ -125,7 +125,7 @@ pub(crate) async fn create_new_session( queue: Default::default(), }; // Save session into global state. - state.sessions.insert(id, session); + sessions.insert(id, session); let user = CreateNewSessionOutput { session_id: id }; Ok(Json(user)) @@ -137,10 +137,9 @@ pub(crate) async fn list_sessions( State(state): State, user: User, ) -> Result, AppError> { - let state = state.sessions.read().unwrap(); + let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap(); - let session_ids = state - .sessions_by_pubkey + let session_ids = sessions_by_pubkey .get(&user.pubkey) .map(|s| s.iter().cloned().collect()) .unwrap_or_default(); @@ -155,24 +154,22 @@ pub(crate) async fn get_session_info( user: User, Json(args): Json, ) -> Result, AppError> { - let state_lock = state.sessions.read().unwrap(); + let sessions = state.sessions.sessions.read().unwrap(); + let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap(); - let sessions = state_lock - .sessions_by_pubkey - .get(&user.pubkey) - .ok_or(AppError( - StatusCode::NOT_FOUND, - eyre!("user is not in any session").into(), - ))?; + let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError( + StatusCode::NOT_FOUND, + eyre!("user is not in any session").into(), + ))?; - if !sessions.contains(&args.session_id) { + if !user_sessions.contains(&args.session_id) { return Err(AppError( StatusCode::NOT_FOUND, eyre!("session ID not found").into(), )); } - let session = state_lock.sessions.get(&args.session_id).ok_or(AppError( + let session = sessions.get(&args.session_id).ok_or(AppError( StatusCode::NOT_FOUND, eyre!("session ID not found").into(), ))?; @@ -194,15 +191,14 @@ pub(crate) async fn send( Json(args): Json, ) -> Result<(), AppError> { // Get the mutex lock to read and write from the state - let mut state_lock = state.sessions.write().unwrap(); + let mut sessions = state.sessions.sessions.write().unwrap(); - let session = state_lock - .sessions - .get_mut(&args.session_id) - .ok_or(AppError( - StatusCode::NOT_FOUND, - eyre!("session ID not found").into(), - ))?; + // TODO: change to get_mut and modify in-place, if HashMapDelay ever + // adds support to it + let mut session = sessions.remove(&args.session_id).ok_or(AppError( + StatusCode::NOT_FOUND, + eyre!("session ID not found").into(), + ))?; let recipients = if args.recipients.is_empty() { vec![Vec::new()] @@ -219,6 +215,7 @@ pub(crate) async fn send( msg: args.msg.clone(), }); } + sessions.insert(args.session_id, session); Ok(()) } @@ -232,15 +229,16 @@ pub(crate) async fn receive( Json(args): Json, ) -> Result, AppError> { // Get the mutex lock to read and write from the state - let mut state_lock = state.sessions.write().unwrap(); + let sessions = state.sessions.sessions.read().unwrap(); - let session = state_lock - .sessions - .get_mut(&args.session_id) - .ok_or(AppError( - StatusCode::NOT_FOUND, - eyre!("session ID not found").into(), - ))?; + // TODO: change to get_mut and modify in-place, if HashMapDelay ever + // adds support to it. This will also simplify the code since + // we have to do a workaround in order to not renew the timeout if there + // are no messages. See https://github.com/AgeManning/delay_map/issues/26 + let session = sessions.get(&args.session_id).ok_or(AppError( + StatusCode::NOT_FOUND, + eyre!("session ID not found").into(), + ))?; let pubkey = if user.pubkey == session.coordinator_pubkey && args.as_coordinator { Vec::new() @@ -248,7 +246,22 @@ pub(crate) async fn receive( user.pubkey }; - let msgs = session.queue.entry(pubkey).or_default().drain(..).collect(); + // If there are no new messages, we don't want to renew the timeout. + // Thus only if there are new messages we drop the read-only lock + // to get the write lock and re-insert the updated session. + let msgs = if session.queue.contains_key(&pubkey) { + drop(sessions); + let mut sessions = state.sessions.sessions.write().unwrap(); + let mut session = sessions.remove(&args.session_id).ok_or(AppError( + StatusCode::NOT_FOUND, + eyre!("session ID not found").into(), + ))?; + let msgs = session.queue.entry(pubkey).or_default().drain(..).collect(); + sessions.insert(args.session_id, session); + msgs + } else { + vec![] + }; Ok(Json(ReceiveOutput { msgs })) } @@ -260,21 +273,22 @@ pub(crate) async fn close_session( user: User, Json(args): Json, ) -> Result, AppError> { - let mut state = state.sessions.write().unwrap(); + let mut sessions = state.sessions.sessions.write().unwrap(); + let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap(); - let sessions = state.sessions_by_pubkey.get(&user.pubkey).ok_or(AppError( + let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError( StatusCode::NOT_FOUND, eyre!("user is not in any session").into(), ))?; - if !sessions.contains(&args.session_id) { + if !user_sessions.contains(&args.session_id) { return Err(AppError( StatusCode::NOT_FOUND, eyre!("session ID not found").into(), )); } - let session = state.sessions.get(&args.session_id).ok_or(AppError( + let session = sessions.get(&args.session_id).ok_or(AppError( StatusCode::INTERNAL_SERVER_ERROR, eyre!("invalid session ID").into(), ))?; @@ -287,10 +301,10 @@ pub(crate) async fn close_session( } for username in session.pubkeys.clone() { - if let Some(v) = state.sessions_by_pubkey.get_mut(&username) { + if let Some(v) = sessions_by_pubkey.get_mut(&username) { v.remove(&args.session_id); } } - state.sessions.remove(&args.session_id); + sessions.remove(&args.session_id); Ok(Json(())) } diff --git a/server/src/state.rs b/server/src/state.rs index 5e4adb8..a3415bd 100644 --- a/server/src/state.rs +++ b/server/src/state.rs @@ -1,18 +1,40 @@ use std::{ collections::{HashMap, HashSet, VecDeque}, + pin::Pin, sync::{Arc, RwLock}, + task::{Context, Poll}, + time::Duration, }; use delay_map::{HashMapDelay, HashSetDelay}; +use futures::{Stream, StreamExt as _}; use uuid::Uuid; use crate::Msg; +/// How long a session stays open. +const SESSION_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60 * 60 * 24); /// How long a challenge can be replied to. const CHALLENGE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); /// How long an acesss token lasts. const ACCESS_TOKEN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60 * 60); +/// Helper struct that allows calling `next()` on a `Stream` behind a `RwLock` +/// (namely a `HashMapDelay` or `HashSetDelay` in our case) without locking +/// the `RwLock` while waiting. +// From https://users.rust-lang.org/t/how-do-i-poll-a-stream-behind-a-rwlock/121787/2 +struct RwLockStream<'a, T>(pub &'a RwLock); + +impl<'a, T: Stream + Unpin> Stream for RwLockStream<'a, T> { + type Item = T::Item; + fn poll_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll::Item>> { + self.0.write().unwrap().poll_next_unpin(cx) + } +} + /// A particular signing session. #[derive(Debug)] pub struct Session { @@ -33,7 +55,7 @@ pub struct Session { /// The global state of the server. #[derive(Debug)] pub struct AppState { - pub(crate) sessions: Arc>, + pub(crate) sessions: SessionState, pub(crate) challenges: Arc>>, pub(crate) access_tokens: Arc>>>, } @@ -41,18 +63,85 @@ pub struct AppState { #[derive(Debug, Default)] pub struct SessionState { /// Mapping of signing sessions by UUID. - pub(crate) sessions: HashMap, - pub(crate) sessions_by_pubkey: HashMap, HashSet>, + pub(crate) sessions: Arc>>, + pub(crate) sessions_by_pubkey: Arc, HashSet>>>, +} + +impl SessionState { + /// Create a new SessionState + pub fn new(timeout: Duration) -> Self { + Self { + sessions: RwLock::new(HashMapDelay::new(timeout)).into(), + sessions_by_pubkey: Default::default(), + } + } } impl AppState { pub async fn new() -> Result> { - let state = Self { - sessions: Default::default(), + let state = Arc::new(Self { + sessions: SessionState::new(SESSION_TIMEOUT), challenges: RwLock::new(HashSetDelay::new(CHALLENGE_TIMEOUT)).into(), access_tokens: RwLock::new(HashMapDelay::new(ACCESS_TOKEN_TIMEOUT)).into(), - }; - Ok(Arc::new(state)) + }); + + // In order to effectively removed timed out entries, we need to + // repeatedly call `next()` on them. + // These tasks will just run forever and will stop when the server stops. + + let state_clone = state.clone(); + tokio::task::spawn(async move { + loop { + match RwLockStream(&state_clone.sessions.sessions).next().await { + Some(Ok((uuid, session))) => { + tracing::debug!("session {} timed out", uuid); + let mut sessions_by_pubkey = + state_clone.sessions.sessions_by_pubkey.write().unwrap(); + for pubkey in session.pubkeys { + if let Some(sessions) = sessions_by_pubkey.get_mut(&pubkey) { + sessions.remove(&uuid); + } + } + } + _ => { + // Annoyingly, if the map is empty, it returns + // immediately instead of waiting for an entry to be + // inserted and waiting for that to timeout. To avoid a + // busy loop when the map is empty, we sleep for a bit. + tokio::time::sleep(Duration::from_secs(1)).await; + } + } + } + }); + // TODO: we could refactor these two loops with a generic function + // but it's just simpler to do this directly currently + let state_clone = state.clone(); + tokio::task::spawn(async move { + loop { + match RwLockStream(&state_clone.challenges).next().await { + Some(Ok(challenge)) => { + tracing::debug!("challenge {} timed out", challenge); + } + _ => { + tokio::time::sleep(Duration::from_secs(1)).await; + } + } + } + }); + let state_clone = state.clone(); + tokio::task::spawn(async move { + loop { + match RwLockStream(&state_clone.access_tokens).next().await { + Some(Ok((access_token, _pubkey))) => { + tracing::debug!("access_token {} timed out", access_token); + } + _ => { + tokio::time::sleep(Duration::from_secs(1)).await; + } + } + } + }); + Ok(state) } }