Skip to content

Commit

Permalink
server: add sessions timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
conradoplg committed Nov 29, 2024
1 parent 319bd69 commit b268e5b
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 47 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.11.0", features = ["v4", "fast-rng", "serde"] }
xeddsa = "1.0.2"
futures-util = "0.3.31"
futures = "0.3.31"

[dev-dependencies]
axum-test = "16.4.0"
Expand Down
94 changes: 54 additions & 40 deletions server/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ pub(crate) async fn create_new_session(
// Create new session object.
let id = Uuid::new_v4();

let mut state = state.sessions.write().unwrap();
let mut sessions = state.sessions.sessions.write().unwrap();
let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap();

// Save session ID in global state
for pubkey in &args.pubkeys {
state
.sessions_by_pubkey
sessions_by_pubkey
.entry(pubkey.clone())
.or_default()
.insert(id);
Expand All @@ -125,7 +125,7 @@ pub(crate) async fn create_new_session(
queue: Default::default(),
};
// Save session into global state.
state.sessions.insert(id, session);
sessions.insert(id, session);

let user = CreateNewSessionOutput { session_id: id };
Ok(Json(user))
Expand All @@ -137,10 +137,9 @@ pub(crate) async fn list_sessions(
State(state): State<SharedState>,
user: User,
) -> Result<Json<ListSessionsOutput>, AppError> {
let state = state.sessions.read().unwrap();
let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap();

let session_ids = state
.sessions_by_pubkey
let session_ids = sessions_by_pubkey
.get(&user.pubkey)
.map(|s| s.iter().cloned().collect())
.unwrap_or_default();
Expand All @@ -155,24 +154,22 @@ pub(crate) async fn get_session_info(
user: User,
Json(args): Json<GetSessionInfoArgs>,
) -> Result<Json<GetSessionInfoOutput>, AppError> {
let state_lock = state.sessions.read().unwrap();
let sessions = state.sessions.sessions.read().unwrap();
let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap();

let sessions = state_lock
.sessions_by_pubkey
.get(&user.pubkey)
.ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;
let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;

if !sessions.contains(&args.session_id) {
if !user_sessions.contains(&args.session_id) {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
));
}

let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
Expand All @@ -194,15 +191,14 @@ pub(crate) async fn send(
Json(args): Json<SendArgs>,
) -> Result<(), AppError> {
// Get the mutex lock to read and write from the state
let mut state_lock = state.sessions.write().unwrap();
let mut sessions = state.sessions.sessions.write().unwrap();

let session = state_lock
.sessions
.get_mut(&args.session_id)
.ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
// TODO: change to get_mut and modify in-place, if HashMapDelay ever
// adds support to it
let mut session = sessions.remove(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;

let recipients = if args.recipients.is_empty() {
vec![Vec::new()]
Expand All @@ -219,6 +215,7 @@ pub(crate) async fn send(
msg: args.msg.clone(),
});
}
sessions.insert(args.session_id, session);

Ok(())
}
Expand All @@ -232,23 +229,39 @@ pub(crate) async fn receive(
Json(args): Json<ReceiveArgs>,
) -> Result<Json<ReceiveOutput>, AppError> {
// Get the mutex lock to read and write from the state
let mut state_lock = state.sessions.write().unwrap();
let sessions = state.sessions.sessions.read().unwrap();

let session = state_lock
.sessions
.get_mut(&args.session_id)
.ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
// TODO: change to get_mut and modify in-place, if HashMapDelay ever
// adds support to it. This will also simplify the code since
// we have to do a workaround in order to not renew the timeout if there
// are no messages. See https://github.com/AgeManning/delay_map/issues/26
let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;

let pubkey = if user.pubkey == session.coordinator_pubkey && args.as_coordinator {
Vec::new()
} else {
user.pubkey
};

let msgs = session.queue.entry(pubkey).or_default().drain(..).collect();
// If there are no new messages, we don't want to renew the timeout.
// Thus only if there are new messages we drop the read-only lock
// to get the write lock and re-insert the updated session.
let msgs = if session.queue.contains_key(&pubkey) {
drop(sessions);
let mut sessions = state.sessions.sessions.write().unwrap();
let mut session = sessions.remove(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
let msgs = session.queue.entry(pubkey).or_default().drain(..).collect();
sessions.insert(args.session_id, session);
msgs
} else {
vec![]
};

Ok(Json(ReceiveOutput { msgs }))
}
Expand All @@ -260,21 +273,22 @@ pub(crate) async fn close_session(
user: User,
Json(args): Json<CloseSessionArgs>,
) -> Result<Json<()>, AppError> {
let mut state = state.sessions.write().unwrap();
let mut sessions = state.sessions.sessions.write().unwrap();
let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap();

let sessions = state.sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;

if !sessions.contains(&args.session_id) {
if !user_sessions.contains(&args.session_id) {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
));
}

let session = state.sessions.get(&args.session_id).ok_or(AppError(
let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid session ID").into(),
))?;
Expand All @@ -287,10 +301,10 @@ pub(crate) async fn close_session(
}

for username in session.pubkeys.clone() {
if let Some(v) = state.sessions_by_pubkey.get_mut(&username) {
if let Some(v) = sessions_by_pubkey.get_mut(&username) {
v.remove(&args.session_id);
}
}
state.sessions.remove(&args.session_id);
sessions.remove(&args.session_id);
Ok(Json(()))
}
103 changes: 96 additions & 7 deletions server/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,40 @@
use std::{
collections::{HashMap, HashSet, VecDeque},
pin::Pin,
sync::{Arc, RwLock},
task::{Context, Poll},
time::Duration,
};

use delay_map::{HashMapDelay, HashSetDelay};
use futures::{Stream, StreamExt as _};
use uuid::Uuid;

use crate::Msg;

/// How long a session stays open.
const SESSION_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60 * 60 * 24);
/// How long a challenge can be replied to.
const CHALLENGE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
/// How long an acesss token lasts.
const ACCESS_TOKEN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60 * 60);

/// Helper struct that allows calling `next()` on a `Stream` behind a `RwLock`
/// (namely a `HashMapDelay` or `HashSetDelay` in our case) without locking
/// the `RwLock` while waiting.
// From https://users.rust-lang.org/t/how-do-i-poll-a-stream-behind-a-rwlock/121787/2
struct RwLockStream<'a, T>(pub &'a RwLock<T>);

impl<'a, T: Stream + Unpin> Stream for RwLockStream<'a, T> {
type Item = T::Item;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<<Self as Stream>::Item>> {
self.0.write().unwrap().poll_next_unpin(cx)
}
}

/// A particular signing session.
#[derive(Debug)]
pub struct Session {
Expand All @@ -33,26 +55,93 @@ pub struct Session {
/// The global state of the server.
#[derive(Debug)]
pub struct AppState {
pub(crate) sessions: Arc<RwLock<SessionState>>,
pub(crate) sessions: SessionState,
pub(crate) challenges: Arc<RwLock<HashSetDelay<Uuid>>>,
pub(crate) access_tokens: Arc<RwLock<HashMapDelay<Uuid, Vec<u8>>>>,
}

#[derive(Debug, Default)]
pub struct SessionState {
/// Mapping of signing sessions by UUID.
pub(crate) sessions: HashMap<Uuid, Session>,
pub(crate) sessions_by_pubkey: HashMap<Vec<u8>, HashSet<Uuid>>,
pub(crate) sessions: Arc<RwLock<HashMapDelay<Uuid, Session>>>,
pub(crate) sessions_by_pubkey: Arc<RwLock<HashMap<Vec<u8>, HashSet<Uuid>>>>,
}

impl SessionState {
/// Create a new SessionState
pub fn new(timeout: Duration) -> Self {
Self {
sessions: RwLock::new(HashMapDelay::new(timeout)).into(),
sessions_by_pubkey: Default::default(),
}
}
}

impl AppState {
pub async fn new() -> Result<SharedState, Box<dyn std::error::Error>> {
let state = Self {
sessions: Default::default(),
let state = Arc::new(Self {
sessions: SessionState::new(SESSION_TIMEOUT),
challenges: RwLock::new(HashSetDelay::new(CHALLENGE_TIMEOUT)).into(),
access_tokens: RwLock::new(HashMapDelay::new(ACCESS_TOKEN_TIMEOUT)).into(),
};
Ok(Arc::new(state))
});

// In order to effectively removed timed out entries, we need to
// repeatedly call `next()` on them.
// These tasks will just run forever and will stop when the server stops.

let state_clone = state.clone();
tokio::task::spawn(async move {
loop {
match RwLockStream(&state_clone.sessions.sessions).next().await {
Some(Ok((uuid, session))) => {
tracing::debug!("session {} timed out", uuid);
let mut sessions_by_pubkey =
state_clone.sessions.sessions_by_pubkey.write().unwrap();
for pubkey in session.pubkeys {
if let Some(sessions) = sessions_by_pubkey.get_mut(&pubkey) {
sessions.remove(&uuid);
}
}
}
_ => {
// Annoyingly, if the map is empty, it returns
// immediately instead of waiting for an entry to be
// inserted and waiting for that to timeout. To avoid a
// busy loop when the map is empty, we sleep for a bit.
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
});
// TODO: we could refactor these two loops with a generic function
// but it's just simpler to do this directly currently
let state_clone = state.clone();
tokio::task::spawn(async move {
loop {
match RwLockStream(&state_clone.challenges).next().await {
Some(Ok(challenge)) => {
tracing::debug!("challenge {} timed out", challenge);
}
_ => {
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
});
let state_clone = state.clone();
tokio::task::spawn(async move {
loop {
match RwLockStream(&state_clone.access_tokens).next().await {
Some(Ok((access_token, _pubkey))) => {
tracing::debug!("access_token {} timed out", access_token);
}
_ => {
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
});
Ok(state)
}
}

Expand Down

0 comments on commit b268e5b

Please sign in to comment.