diff --git a/babushka-core/src/client/client_cmd.rs b/babushka-core/src/client/client_cmd.rs index 4b46d1718d..688931dfbe 100644 --- a/babushka-core/src/client/client_cmd.rs +++ b/babushka-core/src/client/client_cmd.rs @@ -1,14 +1,26 @@ +use super::get_redis_connection_info; +use super::reconnecting_connection::ReconnectingConnection; use crate::connection_request::{ConnectionRequest, TlsMode}; use crate::retry_strategies::RetryStrategy; +use logger_core::{log_debug, log_trace}; use redis::RedisResult; +use std::sync::Arc; +use tokio::task; -use super::get_redis_connection_info; -use super::reconnecting_connection::ReconnectingConnection; +struct DropWrapper { + /// Connection to the primary node in the client. + primary: ReconnectingConnection, +} + +impl Drop for DropWrapper { + fn drop(&mut self) { + self.primary.mark_as_dropped(); + } +} #[derive(Clone)] pub struct ClientCMD { - /// Connection to the primary node in the client. - primary: ReconnectingConnection, + inner: Arc, } impl ClientCMD { @@ -27,27 +39,80 @@ impl ClientCMD { tls_mode, ) .await?; - - Ok(Self { primary }) + Self::start_heartbeat(primary.clone()); + Ok(Self { + inner: Arc::new(DropWrapper { primary }), + }) } pub async fn send_packed_command( &mut self, cmd: &redis::Cmd, ) -> redis::RedisResult { - self.primary.send_packed_command(cmd).await + log_trace("ClientCMD", "sending command"); + let mut connection = self.inner.primary.get_connection().await?; + let result = connection.send_packed_command(cmd).await; + match result { + Err(err) if err.is_connection_dropped() => { + self.inner.primary.reconnect().await; + Err(err) + } + _ => result, + } } - pub(super) async fn send_packed_commands( + pub async fn send_packed_commands( &mut self, cmd: &redis::Pipeline, offset: usize, count: usize, ) -> redis::RedisResult> { - self.primary.send_packed_commands(cmd, offset, count).await + let mut connection = self.inner.primary.get_connection().await?; + let result = connection.send_packed_commands(cmd, offset, count).await; + match result { + Err(err) if err.is_connection_dropped() => { + self.inner.primary.reconnect().await; + Err(err) + } + _ => result, + } } pub(super) fn get_db(&self) -> i64 { - self.primary.get_db() + self.inner.primary.get_db() + } + + fn start_heartbeat(reconnecting_connection: ReconnectingConnection) { + task::spawn(async move { + loop { + tokio::time::sleep(super::HEARTBEAT_SLEEP_DURATION).await; + if reconnecting_connection.is_dropped() { + log_debug( + "ClientCMD", + "heartbeat stopped after connection was dropped", + ); + // Client was dropped, heartbeat can stop. + return; + } + + let Some(mut connection) = reconnecting_connection.try_get_connection().await else { + log_debug( + "ClientCMD", + "heartbeat stopped while connection is reconnecting", + ); + // Client is reconnecting.. + continue; + }; + log_debug("ClientCMD", "performing heartbeat"); + if connection + .send_packed_command(&redis::cmd("PING")) + .await + .is_err_and(|err| err.is_connection_dropped() || err.is_connection_refusal()) + { + log_debug("ClientCMD", "heartbeat triggered reconnect"); + reconnecting_connection.reconnect().await; + } + } + }); } } diff --git a/babushka-core/src/client/mod.rs b/babushka-core/src/client/mod.rs index a5acb0c854..4d4d25c867 100644 --- a/babushka-core/src/client/mod.rs +++ b/babushka-core/src/client/mod.rs @@ -11,6 +11,8 @@ use std::time::Duration; mod client_cmd; mod reconnecting_connection; +pub const HEARTBEAT_SLEEP_DURATION: Duration = Duration::from_secs(1); + pub trait BabushkaClient: ConnectionLike + Send + Clone {} impl BabushkaClient for MultiplexedConnection {} diff --git a/babushka-core/src/client/reconnecting_connection.rs b/babushka-core/src/client/reconnecting_connection.rs index 8806706064..6a0bcb711d 100644 --- a/babushka-core/src/client/reconnecting_connection.rs +++ b/babushka-core/src/client/reconnecting_connection.rs @@ -36,23 +36,9 @@ struct InnerReconnectingConnection { backend: ConnectionBackend, } -/// The separation between an inner and outer connection is because the outer connection is clonable, and the inner connection needs to be dropped when no outer connection exists. -struct DropWrapper(Arc); - -impl Drop for DropWrapper { - fn drop(&mut self) { - self.0 - .backend - .client_dropped_flagged - .store(true, Ordering::Relaxed); - } -} - #[derive(Clone)] pub(super) struct ReconnectingConnection { - /// All of the connection's clones point to the same internal wrapper, which will be dropped only once, - /// when all of the clones have been dropped. - inner: Arc, + inner: Arc, } async fn get_multiplexed_connection(client: &redis::Client) -> RedisResult { @@ -75,10 +61,10 @@ async fn try_create_connection( let connection = Retry::spawn(retry_strategy.get_iterator(), action).await?; Ok(ReconnectingConnection { - inner: Arc::new(DropWrapper(Arc::new(InnerReconnectingConnection { + inner: Arc::new(InnerReconnectingConnection { state: Mutex::new(ConnectionState::Connected(connection)), backend: connection_backend, - }))), + }), }) } @@ -112,7 +98,7 @@ impl ReconnectingConnection { connection_retry_strategy: RetryStrategy, redis_connection_info: RedisConnectionInfo, tls_mode: TlsMode, - ) -> RedisResult { + ) -> RedisResult { log_debug( "connection creation", format!("Attempting connection to {address}"), @@ -131,29 +117,44 @@ impl ReconnectingConnection { Ok(connection) } - async fn get_connection(&self) -> Result { + pub(super) fn is_dropped(&self) -> bool { + self.inner + .backend + .client_dropped_flagged + .load(Ordering::Relaxed) + } + + pub(super) fn mark_as_dropped(&self) { + self.inner + .backend + .client_dropped_flagged + .store(true, Ordering::Relaxed) + } + + pub(super) async fn try_get_connection(&self) -> Option { + let guard = self.inner.state.lock().await; + if let ConnectionState::Connected(connection) = &*guard { + Some(connection.clone()) + } else { + None + } + } + + pub(super) async fn get_connection(&self) -> Result { loop { - self.inner - .0 - .backend - .connection_available_signal - .wait() - .await; - { - let guard = self.inner.0.state.lock().await; - if let ConnectionState::Connected(connection) = &*guard { - return Ok(connection.clone()); - } - }; + self.inner.backend.connection_available_signal.wait().await; + if let Some(connection) = self.try_get_connection().await { + return Ok(connection); + } } } - async fn reconnect(&self) { + pub(super) async fn reconnect(&self) { { - let mut guard = self.inner.0.state.lock().await; + let mut guard = self.inner.state.lock().await; match &*guard { ConnectionState::Connected(_) => { - self.inner.0.backend.connection_available_signal.reset(); + self.inner.backend.connection_available_signal.reset(); } _ => { log_trace("reconnect", "already started"); @@ -164,18 +165,14 @@ impl ReconnectingConnection { *guard = ConnectionState::Reconnecting; }; log_debug("reconnect", "starting"); - let inner_connection_clone = self.inner.0.clone(); + let connection_clone = self.clone(); // The reconnect task is spawned instead of awaited here, so that the reconnect attempt will continue in the // background, regardless of whether the calling task is dropped or not. task::spawn(async move { - let client = &inner_connection_clone.backend.connection_info; + let client = &connection_clone.inner.backend.connection_info; for sleep_duration in internal_retry_iterator() { - if inner_connection_clone - .backend - .client_dropped_flagged - .load(Ordering::Relaxed) - { - log_trace( + if connection_clone.is_dropped() { + log_debug( "ReconnectingConnection", "reconnect stopped after client was dropped", ); @@ -185,13 +182,16 @@ impl ReconnectingConnection { log_debug("connection creation", "Creating multiplexed connection"); match get_multiplexed_connection(client).await { Ok(connection) => { - let mut guard = inner_connection_clone.state.lock().await; - log_debug("reconnect", "completed succesfully"); - inner_connection_clone - .backend - .connection_available_signal - .set(); - *guard = ConnectionState::Connected(connection); + { + let mut guard = connection_clone.inner.state.lock().await; + log_debug("reconnect", "completed succesfully"); + connection_clone + .inner + .backend + .connection_available_signal + .set(); + *guard = ConnectionState::Connected(connection); + } return; } Err(_) => tokio::time::sleep(sleep_duration).await, @@ -200,41 +200,8 @@ impl ReconnectingConnection { }); } - pub(super) async fn send_packed_command( - &mut self, - cmd: &redis::Cmd, - ) -> redis::RedisResult { - log_trace("ReconnectingConnection", "sending command"); - let mut connection = self.get_connection().await?; - let result = connection.send_packed_command(cmd).await; - match result { - Err(err) if err.is_connection_dropped() => { - self.reconnect().await; - Err(err) - } - _ => result, - } - } - - pub(super) async fn send_packed_commands( - &mut self, - cmd: &redis::Pipeline, - offset: usize, - count: usize, - ) -> redis::RedisResult> { - let mut connection = self.get_connection().await?; - let result = connection.send_packed_commands(cmd, offset, count).await; - match result { - Err(err) if err.is_connection_dropped() => { - self.reconnect().await; - Err(err) - } - _ => result, - } - } - pub(super) fn get_db(&self) -> i64 { - let guard = self.inner.0.state.blocking_lock(); + let guard = self.inner.state.blocking_lock(); match &*guard { ConnectionState::Connected(connection) => connection.get_db(), _ => -1, diff --git a/babushka-core/tests/test_client_cmd.rs b/babushka-core/tests/test_client_cmd.rs index b2c387d7f3..d118417efd 100644 --- a/babushka-core/tests/test_client_cmd.rs +++ b/babushka-core/tests/test_client_cmd.rs @@ -23,6 +23,7 @@ mod client_cmd_tests { let address = server.get_client_addr(); drop(server); + // we use another thread, so that the creation of the server won't block the client work. let thread = std::thread::spawn(move || { block_on_all(async move { let mut get_command = redis::Cmd::new(); @@ -53,4 +54,41 @@ mod client_cmd_tests { thread.join().unwrap(); }); } + + #[rstest] + #[timeout(LONG_CMD_TEST_TIMEOUT)] + fn test_detect_disconnect_and_reconnect_using_heartbeat(#[values(false, true)] use_tls: bool) { + let (sender, receiver) = tokio::sync::oneshot::channel(); + block_on_all(async move { + let mut test_basics = setup_test_basics(use_tls).await; + let server = test_basics.server; + let address = server.get_client_addr(); + println!("dropping server"); + drop(server); + + // we use another thread, so that the creation of the server won't block the client work. + std::thread::spawn(move || { + block_on_all(async move { + let new_server = RedisServer::new_with_addr_and_modules(address.clone(), &[]); + wait_for_server_to_become_ready(&address).await; + let _ = sender.send(new_server); + }) + }); + + let _new_server = receiver.await; + tokio::time::sleep(babushka::client::HEARTBEAT_SLEEP_DURATION + Duration::from_secs(1)) + .await; + + let mut get_command = redis::Cmd::new(); + get_command + .arg("GET") + .arg("test_detect_disconnect_and_reconnect_using_heartbeat"); + let get_result = test_basics + .client + .send_packed_command(&get_command) + .await + .unwrap(); + assert_eq!(get_result, Value::Nil); + }); + } }