Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: add sessions timeouts #386

Merged
merged 2 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.11.0", features = ["v4", "fast-rng", "serde"] }
xeddsa = "1.0.2"
futures-util = "0.3.31"
futures = "0.3.31"
hex = "0.4.3"

[dev-dependencies]
Expand Down
94 changes: 54 additions & 40 deletions server/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ pub(crate) async fn create_new_session(
// Create new session object.
let id = Uuid::new_v4();

let mut state = state.sessions.write().unwrap();
let mut sessions = state.sessions.sessions.write().unwrap();
let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap();

// Save session ID in global state
for pubkey in &args.pubkeys {
state
.sessions_by_pubkey
sessions_by_pubkey
.entry(pubkey.0.clone())
.or_default()
.insert(id);
Expand All @@ -125,7 +125,7 @@ pub(crate) async fn create_new_session(
queue: Default::default(),
};
// Save session into global state.
state.sessions.insert(id, session);
sessions.insert(id, session);

let user = CreateNewSessionOutput { session_id: id };
Ok(Json(user))
Expand All @@ -137,10 +137,9 @@ pub(crate) async fn list_sessions(
State(state): State<SharedState>,
user: User,
) -> Result<Json<ListSessionsOutput>, AppError> {
let state = state.sessions.read().unwrap();
let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap();

let session_ids = state
.sessions_by_pubkey
let session_ids = sessions_by_pubkey
.get(&user.pubkey)
.map(|s| s.iter().cloned().collect())
.unwrap_or_default();
Expand All @@ -155,24 +154,22 @@ pub(crate) async fn get_session_info(
user: User,
Json(args): Json<GetSessionInfoArgs>,
) -> Result<Json<GetSessionInfoOutput>, AppError> {
let state_lock = state.sessions.read().unwrap();
let sessions = state.sessions.sessions.read().unwrap();
let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap();

let sessions = state_lock
.sessions_by_pubkey
.get(&user.pubkey)
.ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;
let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;

if !sessions.contains(&args.session_id) {
if !user_sessions.contains(&args.session_id) {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
));
}

let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
Expand All @@ -194,15 +191,14 @@ pub(crate) async fn send(
Json(args): Json<SendArgs>,
) -> Result<(), AppError> {
// Get the mutex lock to read and write from the state
let mut state_lock = state.sessions.write().unwrap();
let mut sessions = state.sessions.sessions.write().unwrap();

let session = state_lock
.sessions
.get_mut(&args.session_id)
.ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
// TODO: change to get_mut and modify in-place, if HashMapDelay ever
// adds support for it
let mut session = sessions.remove(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;

let recipients = if args.recipients.is_empty() {
vec![Vec::new()]
Expand All @@ -219,6 +215,7 @@ pub(crate) async fn send(
msg: args.msg.clone(),
});
}
sessions.insert(args.session_id, session);

Ok(())
}
Expand All @@ -232,23 +229,39 @@ pub(crate) async fn receive(
Json(args): Json<ReceiveArgs>,
) -> Result<Json<ReceiveOutput>, AppError> {
// Get the mutex lock to read and write from the state
let mut state_lock = state.sessions.write().unwrap();
let sessions = state.sessions.sessions.read().unwrap();

let session = state_lock
.sessions
.get_mut(&args.session_id)
.ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
// TODO: change to get_mut and modify in-place, if HashMapDelay ever
// adds support for it. This will also simplify the code since
// we have to do a workaround in order to not renew the timeout if there
// are no messages. See https://github.com/AgeManning/delay_map/issues/26
let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;

let pubkey = if user.pubkey == session.coordinator_pubkey && args.as_coordinator {
Vec::new()
} else {
user.pubkey
};

let msgs = session.queue.entry(pubkey).or_default().drain(..).collect();
// If there are no new messages, we don't want to renew the timeout.
// Thus only if there are new messages we drop the read-only lock
// to get the write lock and re-insert the updated session.
let msgs = if session.queue.contains_key(&pubkey) {
drop(sessions);
let mut sessions = state.sessions.sessions.write().unwrap();
let mut session = sessions.remove(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
let msgs = session.queue.entry(pubkey).or_default().drain(..).collect();
sessions.insert(args.session_id, session);
msgs
} else {
vec![]
};

Ok(Json(ReceiveOutput { msgs }))
}
Expand All @@ -260,21 +273,22 @@ pub(crate) async fn close_session(
user: User,
Json(args): Json<CloseSessionArgs>,
) -> Result<Json<()>, AppError> {
let mut state = state.sessions.write().unwrap();
let mut sessions = state.sessions.sessions.write().unwrap();
let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap();

let sessions = state.sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;

if !sessions.contains(&args.session_id) {
if !user_sessions.contains(&args.session_id) {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
));
}

let session = state.sessions.get(&args.session_id).ok_or(AppError(
let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid session ID").into(),
))?;
Expand All @@ -287,10 +301,10 @@ pub(crate) async fn close_session(
}

for username in session.pubkeys.clone() {
if let Some(v) = state.sessions_by_pubkey.get_mut(&username) {
if let Some(v) = sessions_by_pubkey.get_mut(&username) {
v.remove(&args.session_id);
}
}
state.sessions.remove(&args.session_id);
sessions.remove(&args.session_id);
Ok(Json(()))
}
105 changes: 96 additions & 9 deletions server/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,40 @@
use std::{
collections::{HashMap, HashSet, VecDeque},
pin::Pin,
sync::{Arc, RwLock},
task::{Context, Poll},
time::Duration,
};

use delay_map::{HashMapDelay, HashSetDelay};
use futures::{Stream, StreamExt as _};
use uuid::Uuid;

use crate::Msg;

/// How long a session stays open.
const SESSION_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60 * 60 * 24);
/// How long a challenge can be replied to.
const CHALLENGE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
/// How long an access token lasts.
const ACCESS_TOKEN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60 * 60);

/// Helper struct that allows calling `next()` on a `Stream` behind a `RwLock`
/// (namely a `HashMapDelay` or `HashSetDelay` in our case) without locking
/// the `RwLock` while waiting.
// From https://users.rust-lang.org/t/how-do-i-poll-a-stream-behind-a-rwlock/121787/2
struct RwLockStream<'a, T>(pub &'a RwLock<T>);

impl<T: Stream + Unpin> Stream for RwLockStream<'_, T> {
type Item = T::Item;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<<Self as Stream>::Item>> {
self.0.write().unwrap().poll_next_unpin(cx)
}
}

/// A particular signing session.
#[derive(Debug)]
pub struct Session {
Expand All @@ -22,8 +44,6 @@ pub struct Session {
pub(crate) coordinator_pubkey: Vec<u8>,
/// The number of signers in the session.
pub(crate) num_signers: u16,
/// The set of identifiers for the session.
// pub(crate) identifiers: BTreeSet<SerializedIdentifier>,
/// The number of messages being simultaneously signed.
pub(crate) message_count: u8,
/// The message queue.
Expand All @@ -33,26 +53,93 @@ pub struct Session {
/// The global state of the server.
#[derive(Debug)]
pub struct AppState {
pub(crate) sessions: Arc<RwLock<SessionState>>,
pub(crate) sessions: SessionState,
pub(crate) challenges: Arc<RwLock<HashSetDelay<Uuid>>>,
pub(crate) access_tokens: Arc<RwLock<HashMapDelay<Uuid, Vec<u8>>>>,
}

#[derive(Debug, Default)]
pub struct SessionState {
/// Mapping of signing sessions by UUID.
pub(crate) sessions: HashMap<Uuid, Session>,
pub(crate) sessions_by_pubkey: HashMap<Vec<u8>, HashSet<Uuid>>,
pub(crate) sessions: Arc<RwLock<HashMapDelay<Uuid, Session>>>,
pub(crate) sessions_by_pubkey: Arc<RwLock<HashMap<Vec<u8>, HashSet<Uuid>>>>,
}

impl SessionState {
/// Create a new SessionState
pub fn new(timeout: Duration) -> Self {
Self {
sessions: RwLock::new(HashMapDelay::new(timeout)).into(),
sessions_by_pubkey: Default::default(),
}
}
}

impl AppState {
pub async fn new() -> Result<SharedState, Box<dyn std::error::Error>> {
let state = Self {
sessions: Default::default(),
let state = Arc::new(Self {
sessions: SessionState::new(SESSION_TIMEOUT),
challenges: RwLock::new(HashSetDelay::new(CHALLENGE_TIMEOUT)).into(),
access_tokens: RwLock::new(HashMapDelay::new(ACCESS_TOKEN_TIMEOUT)).into(),
};
Ok(Arc::new(state))
});

// In order to effectively remove timed-out entries, we need to
// repeatedly call `next()` on them.
// These tasks will just run forever and will stop when the server stops.

let state_clone = state.clone();
tokio::task::spawn(async move {
loop {
match RwLockStream(&state_clone.sessions.sessions).next().await {
Some(Ok((uuid, session))) => {
tracing::debug!("session {} timed out", uuid);
let mut sessions_by_pubkey =
state_clone.sessions.sessions_by_pubkey.write().unwrap();
for pubkey in session.pubkeys {
if let Some(sessions) = sessions_by_pubkey.get_mut(&pubkey) {
sessions.remove(&uuid);
}
}
}
_ => {
// Annoyingly, if the map is empty, it returns
// immediately instead of waiting for an entry to be
// inserted and waiting for that to timeout. To avoid a
// busy loop when the map is empty, we sleep for a bit.
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
});
// TODO: we could refactor these two loops with a generic function
// but it's just simpler to do this directly currently
let state_clone = state.clone();
tokio::task::spawn(async move {
loop {
match RwLockStream(&state_clone.challenges).next().await {
Some(Ok(challenge)) => {
tracing::debug!("challenge {} timed out", challenge);
}
_ => {
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
});
let state_clone = state.clone();
tokio::task::spawn(async move {
loop {
match RwLockStream(&state_clone.access_tokens).next().await {
Some(Ok((access_token, _pubkey))) => {
tracing::debug!("access_token {} timed out", access_token);
}
_ => {
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
});
Ok(state)
}
}

Expand Down
Loading