From 664da0ebed51cb8dce6973986e045bc284ecb91b Mon Sep 17 00:00:00 2001 From: Shachar Langbeheim Date: Mon, 17 Jul 2023 10:23:17 +0000 Subject: [PATCH] Add CMD read from replica. --- babushka-core/src/client/client_cmd.rs | 281 +++++++++++++----- .../src/client/reconnecting_connection.rs | 94 ++++-- babushka-core/tests/test_client_cmd.rs | 152 +++++++++- babushka-core/tests/test_socket_listener.rs | 9 +- babushka-core/tests/utilities/mocks.rs | 8 + 5 files changed, 437 insertions(+), 107 deletions(-) diff --git a/babushka-core/src/client/client_cmd.rs b/babushka-core/src/client/client_cmd.rs index 9532512c4e..e69bb3ec2c 100644 --- a/babushka-core/src/client/client_cmd.rs +++ b/babushka-core/src/client/client_cmd.rs @@ -1,22 +1,33 @@ use super::get_redis_connection_info; use super::reconnecting_connection::ReconnectingConnection; -use crate::connection_request::{ConnectionRequest, TlsMode}; +use crate::connection_request::{AddressInfo, ConnectionRequest, TlsMode}; use crate::retry_strategies::RetryStrategy; use futures::{stream, StreamExt}; -use logger_core::{log_debug, log_trace}; -use redis::RedisError; +use logger_core::{log_debug, log_trace, log_warn}; +use protobuf::EnumOrUnknown; +use redis::{RedisError, RedisResult, Value}; use std::sync::Arc; use tokio::task; +enum ReadFromReplicaStrategy { + AlwaysFromPrimary, + RoundRobin { + latest_read_replica_index: Arc, + }, +} + struct DropWrapper { /// Connection to the primary node in the client. - primary: ReconnectingConnection, - replicas: Vec, + primary_index: usize, + nodes: Vec, + read_from_replica_strategy: ReadFromReplicaStrategy, } impl Drop for DropWrapper { fn drop(&mut self) { - self.primary.mark_as_dropped(); + for node in self.nodes.iter() { + node.mark_as_dropped(); + } } } @@ -27,7 +38,6 @@ pub struct ClientCMD { pub enum ClientCMDConnectionError { NoAddressesProvided, - NoPrimaryFound, FailedConnection(Vec<(String, RedisError)>), } @@ -37,9 +47,6 @@ impl std::fmt::Debug for ClientCMDConnectionError { ClientCMDConnectionError::NoAddressesProvided => { write!(f, "No addresses provided") } - ClientCMDConnectionError::NoPrimaryFound => { - write!(f, "No primary node found") - } ClientCMDConnectionError::FailedConnection(errs) => { writeln!(f, "Received errors:")?; for (address, error) in errs { @@ -63,82 +70,119 @@ impl ClientCMD { get_redis_connection_info(connection_request.authentication_info.0); let tls_mode = connection_request.tls_mode.enum_value_or(TlsMode::NoTls); - let connections = stream::iter(connection_request.addresses.iter()) + let node_count = connection_request.addresses.len(); + let mut stream = stream::iter(connection_request.addresses.iter()) .map(|address| async { - ( - format!("{}:{}", address.host, address.port), - async { - let reconnecting_connection = ReconnectingConnection::new( - address, - retry_strategy.clone(), - redis_connection_info.clone(), - tls_mode, - ) - .await?; - let mut multiplexed_connection = - reconnecting_connection.get_connection().await?; - let replication_status = multiplexed_connection - .send_packed_command(redis::cmd("INFO").arg("REPLICATION")) - .await?; - Ok((reconnecting_connection, replication_status)) - } - .await, + get_connection_and_replication_info( + address, + &retry_strategy, + &redis_connection_info, + tls_mode, ) + .await + .map_err(|err| (format!("{}:{}", address.host, address.port), err)) }) - .buffer_unordered(connection_request.addresses.len()) - .collect::>() - .await; - - if connections.iter().any(|(_, result)| result.is_err()) { - let addresses_and_errors: Vec<(String, RedisError)> = connections - .into_iter() - .filter_map(|(address, result)| result.err().map(|err| (address, err))) - .collect(); - return Err(ClientCMDConnectionError::FailedConnection( - addresses_and_errors, - )); + .buffer_unordered(node_count); + + let mut nodes = Vec::with_capacity(node_count); + let mut addresses_and_errors = Vec::with_capacity(node_count); + let mut primary_index = None; + while let Some(result) = stream.next().await { + match result { + Ok((connection, replication_status)) => { + nodes.push(connection); + if primary_index.is_none() + && redis::from_redis_value::(&replication_status) + .is_ok_and(|val| val.contains("role:master")) + { + primary_index = Some(nodes.len() - 1); + } + } + Err((address, (connection, err))) => { + nodes.push(connection); + addresses_and_errors.push((address, err)); + } + } } - let results: Vec<(ReconnectingConnection, redis::Value)> = connections - .into_iter() - .map(|(_, result)| result.unwrap()) - .collect(); - let Some(primary_index) = results.iter().position(|(_, replication_status)| - redis::from_redis_value::(replication_status).ok().and_then(|val|if val.contains("role:master") { Some(())} else {None}).is_some() - ) else { - return Err(ClientCMDConnectionError::NoPrimaryFound); + + let Some(primary_index) = primary_index else { + return Err(ClientCMDConnectionError::FailedConnection(addresses_and_errors)); }; - let mut connections: Vec = results - .into_iter() - .map(|(connection, _)| connection) - .collect(); - let Some(primary) = connections - .drain(primary_index..primary_index+1) - .next() else { - return Err(ClientCMDConnectionError::NoPrimaryFound); - }; + if !addresses_and_errors.is_empty() { + log_warn( + "client creation", + format!( + "Failed to connect to {addresses_and_errors:?}, will attempt to reconnect." + ), + ); + } + let read_from_replica_strategy = + get_read_from_replica_strategy(&connection_request.read_from_replica_strategy); - Self::start_heartbeat(primary.clone()); - for connection in connections.iter() { - Self::start_heartbeat(connection.clone()); + for node in nodes.iter() { + Self::start_heartbeat(node.clone()); } + Ok(Self { inner: Arc::new(DropWrapper { - primary, - replicas: connections, + primary_index, + nodes, + read_from_replica_strategy, }), }) } - pub async fn send_packed_command( - &mut self, - cmd: &redis::Cmd, - ) -> redis::RedisResult { + fn get_primary_connection(&self) -> &ReconnectingConnection { + self.inner.nodes.get(self.inner.primary_index).unwrap() + } + + fn get_connection(&self, cmd: &redis::Cmd) -> &ReconnectingConnection { + if !is_readonly_cmd(cmd) || self.inner.nodes.len() == 1 { + return self.get_primary_connection(); + } + match &self.inner.read_from_replica_strategy { + ReadFromReplicaStrategy::AlwaysFromPrimary => self.get_primary_connection(), + ReadFromReplicaStrategy::RoundRobin { + latest_read_replica_index, + } => { + let initial_index = latest_read_replica_index + .load(std::sync::atomic::Ordering::Relaxed) + % self.inner.nodes.len(); + loop { + let index = (latest_read_replica_index + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) + + 1) + % self.inner.nodes.len(); + + // Looped through all replicas, no connected replica was found. + if index == initial_index { + return self.get_primary_connection(); + } + if index == self.inner.primary_index { + continue; + } + let Some(connection) = self + .inner + .nodes + .get(index) else { + continue; + }; + if connection.is_connected() { + return connection; + } + } + } + } + } + + pub async fn send_packed_command(&mut self, cmd: &redis::Cmd) -> RedisResult { log_trace("ClientCMD", "sending command"); - let mut connection = self.inner.primary.get_connection().await?; + let reconnecting_connection = self.get_connection(cmd); + let mut connection = reconnecting_connection.get_connection().await?; let result = connection.send_packed_command(cmd).await; match result { Err(err) if err.is_connection_dropped() => { - self.inner.primary.reconnect(); + reconnecting_connection.reconnect(false); Err(err) } _ => result, @@ -150,12 +194,13 @@ impl ClientCMD { cmd: &redis::Pipeline, offset: usize, count: usize, - ) -> redis::RedisResult> { - let mut connection = self.inner.primary.get_connection().await?; + ) -> RedisResult> { + let reconnecting_connection = self.get_primary_connection(); + let mut connection = reconnecting_connection.get_connection().await?; let result = connection.send_packed_commands(cmd, offset, count).await; match result { Err(err) if err.is_connection_dropped() => { - self.inner.primary.reconnect(); + reconnecting_connection.reconnect(false); Err(err) } _ => result, @@ -190,9 +235,99 @@ impl ClientCMD { .is_err_and(|err| err.is_connection_dropped() || err.is_connection_refusal()) { log_debug("ClientCMD", "heartbeat triggered reconnect"); - reconnecting_connection.reconnect(); + reconnecting_connection.reconnect(false); } } }); } } + +async fn get_connection_and_replication_info( + address: &AddressInfo, + retry_strategy: &RetryStrategy, + connection_info: &redis::RedisConnectionInfo, + tls_mode: TlsMode, +) -> Result<(ReconnectingConnection, Value), (ReconnectingConnection, RedisError)> { + let result = ReconnectingConnection::new( + address, + retry_strategy.clone(), + connection_info.clone(), + tls_mode, + ) + .await; + match result { + Ok(reconnecting_connection) => { + let mut multiplexed_connection = match reconnecting_connection.get_connection().await { + Ok(multiplexed_connection) => multiplexed_connection, + Err(err) => { + reconnecting_connection.reconnect(true); + return Err((reconnecting_connection, err)); + } + }; + + match multiplexed_connection + .send_packed_command(redis::cmd("INFO").arg("REPLICATION")) + .await + { + Ok(replication_status) => Ok((reconnecting_connection, replication_status)), + Err(err) => Err((reconnecting_connection, err)), + } + } + Err(tuple) => { + tuple.0.reconnect(true); + Err(tuple) + } + } +} + +// Copied and djusted from redis-rs +fn is_readonly_cmd(cmd: &redis::Cmd) -> bool { + matches!( + cmd.args_iter().next(), + // @admin + Some(redis::Arg::Simple(b"LASTSAVE")) | + // @bitmap + Some(redis::Arg::Simple(b"BITCOUNT")) | Some(redis::Arg::Simple(b"BITFIELD_RO")) | Some(redis::Arg::Simple(b"BITPOS")) | Some(redis::Arg::Simple(b"GETBIT")) | + // @connection + Some(redis::Arg::Simple(b"CLIENT")) | Some(redis::Arg::Simple(b"ECHO")) | + // @geo + Some(redis::Arg::Simple(b"GEODIST")) | Some(redis::Arg::Simple(b"GEOHASH")) | Some(redis::Arg::Simple(b"GEOPOS")) | Some(redis::Arg::Simple(b"GEORADIUSBYMEMBER_RO")) | Some(redis::Arg::Simple(b"GEORADIUS_RO")) | Some(redis::Arg::Simple(b"GEOSEARCH")) | + // @hash + Some(redis::Arg::Simple(b"HEXISTS")) | Some(redis::Arg::Simple(b"HGET")) | Some(redis::Arg::Simple(b"HGETALL")) | Some(redis::Arg::Simple(b"HKEYS")) | Some(redis::Arg::Simple(b"HLEN")) | Some(redis::Arg::Simple(b"HMGET")) | Some(redis::Arg::Simple(b"HRANDFIELD")) | Some(redis::Arg::Simple(b"HSCAN")) | Some(redis::Arg::Simple(b"HSTRLEN")) | Some(redis::Arg::Simple(b"HVALS")) | + // @hyperloglog + Some(redis::Arg::Simple(b"PFCOUNT")) | + // @keyspace + Some(redis::Arg::Simple(b"DBSIZE")) | Some(redis::Arg::Simple(b"DUMP")) | Some(redis::Arg::Simple(b"EXISTS")) | Some(redis::Arg::Simple(b"EXPIRETIME")) | Some(redis::Arg::Simple(b"KEYS")) | Some(redis::Arg::Simple(b"OBJECT")) | Some(redis::Arg::Simple(b"PEXPIRETIME")) | Some(redis::Arg::Simple(b"PTTL")) | Some(redis::Arg::Simple(b"RANDOMKEY")) | Some(redis::Arg::Simple(b"SCAN")) | Some(redis::Arg::Simple(b"TOUCH")) | Some(redis::Arg::Simple(b"TTL")) | Some(redis::Arg::Simple(b"TYPE")) | + // @list + Some(redis::Arg::Simple(b"LINDEX")) | Some(redis::Arg::Simple(b"LLEN")) | Some(redis::Arg::Simple(b"LPOS")) | Some(redis::Arg::Simple(b"LRANGE")) | Some(redis::Arg::Simple(b"SORT_RO")) | + // @scripting + Some(redis::Arg::Simple(b"EVALSHA_RO")) | Some(redis::Arg::Simple(b"EVAL_RO")) | Some(redis::Arg::Simple(b"FCALL_RO")) | + // @set + Some(redis::Arg::Simple(b"SCARD")) | Some(redis::Arg::Simple(b"SDIFF")) | Some(redis::Arg::Simple(b"SINTER")) | Some(redis::Arg::Simple(b"SINTERCARD")) | Some(redis::Arg::Simple(b"SISMEMBER")) | Some(redis::Arg::Simple(b"SMEMBERS")) | Some(redis::Arg::Simple(b"SMISMEMBER")) | Some(redis::Arg::Simple(b"SRANDMEMBER")) | Some(redis::Arg::Simple(b"SSCAN")) | Some(redis::Arg::Simple(b"SUNION")) | + // @sortedset + Some(redis::Arg::Simple(b"ZCARD")) | Some(redis::Arg::Simple(b"ZCOUNT")) | Some(redis::Arg::Simple(b"ZDIFF")) | Some(redis::Arg::Simple(b"ZINTER")) | Some(redis::Arg::Simple(b"ZINTERCARD")) | Some(redis::Arg::Simple(b"ZLEXCOUNT")) | Some(redis::Arg::Simple(b"ZMSCORE")) | Some(redis::Arg::Simple(b"ZRANDMEMBER")) | Some(redis::Arg::Simple(b"ZRANGE")) | Some(redis::Arg::Simple(b"ZRANGEBYLEX")) | Some(redis::Arg::Simple(b"ZRANGEBYSCORE")) | Some(redis::Arg::Simple(b"ZRANK")) | Some(redis::Arg::Simple(b"ZREVRANGE")) | Some(redis::Arg::Simple(b"ZREVRANGEBYLEX")) | Some(redis::Arg::Simple(b"ZREVRANGEBYSCORE")) | Some(redis::Arg::Simple(b"ZREVRANK")) | Some(redis::Arg::Simple(b"ZSCAN")) | Some(redis::Arg::Simple(b"ZSCORE")) | Some(redis::Arg::Simple(b"ZUNION")) | + // @stream + Some(redis::Arg::Simple(b"XINFO")) | Some(redis::Arg::Simple(b"XLEN")) | Some(redis::Arg::Simple(b"XPENDING")) | Some(redis::Arg::Simple(b"XRANGE")) | Some(redis::Arg::Simple(b"XREAD")) | Some(redis::Arg::Simple(b"XREVRANGE")) | + // @string + Some(redis::Arg::Simple(b"GET")) | Some(redis::Arg::Simple(b"GETRANGE")) | Some(redis::Arg::Simple(b"LCS")) | Some(redis::Arg::Simple(b"MGET")) | Some(redis::Arg::Simple(b"STRALGO")) | Some(redis::Arg::Simple(b"STRLEN")) | Some(redis::Arg::Simple(b"SUBSTR")) + ) +} + +fn get_read_from_replica_strategy( + read_from_replica_strategy: &EnumOrUnknown, +) -> ReadFromReplicaStrategy { + match read_from_replica_strategy + .enum_value_or(crate::connection_request::ReadFromReplicaStrategy::AlwaysFromPrimary) + { + crate::connection_request::ReadFromReplicaStrategy::AlwaysFromPrimary => { + ReadFromReplicaStrategy::AlwaysFromPrimary + } + crate::connection_request::ReadFromReplicaStrategy::RoundRobin => { + ReadFromReplicaStrategy::RoundRobin { + latest_read_replica_index: Default::default(), + } + } + crate::connection_request::ReadFromReplicaStrategy::LowestLatency => todo!(), + crate::connection_request::ReadFromReplicaStrategy::AZAffinity => todo!(), + } +} diff --git a/babushka-core/src/client/reconnecting_connection.rs b/babushka-core/src/client/reconnecting_connection.rs index 69d1d59137..73e011c23b 100644 --- a/babushka-core/src/client/reconnecting_connection.rs +++ b/babushka-core/src/client/reconnecting_connection.rs @@ -1,7 +1,7 @@ use crate::connection_request::{AddressInfo, TlsMode}; use crate::retry_strategies::RetryStrategy; use futures_intrusive::sync::ManualResetEvent; -use logger_core::{log_debug, log_trace}; +use logger_core::{log_debug, log_trace, log_warn}; use redis::aio::MultiplexedConnection; use redis::{RedisConnectionInfo, RedisError, RedisResult}; use std::sync::atomic::{AtomicBool, Ordering}; @@ -11,7 +11,7 @@ use std::time::Duration; use tokio::task; use tokio_retry::Retry; -use super::{get_connection_info, run_with_timeout, DEFAULT_CONNECTION_ATTEMPT_TIMEOUT}; +use super::{run_with_timeout, DEFAULT_CONNECTION_ATTEMPT_TIMEOUT}; /// The object that is used in order to recreate a connection after a disconnect. struct ConnectionBackend { @@ -49,35 +49,67 @@ async fn get_multiplexed_connection(client: &redis::Client) -> RedisResult RedisResult { +) -> Result { let client = &connection_backend.connection_info; - let action = || { - log_debug("connection creation", "Creating multiplexed connection"); - get_multiplexed_connection(client) - }; - - let connection = Retry::spawn(retry_strategy.get_iterator(), action).await?; - Ok(ReconnectingConnection { - inner: Arc::new(InnerReconnectingConnection { - state: Mutex::new(ConnectionState::Connected(connection)), - backend: connection_backend, - }), - }) + let action = || get_multiplexed_connection(client); + + match Retry::spawn(retry_strategy.get_iterator(), action).await { + Ok(connection) => { + log_debug( + "connection creation", + format!( + "Connection to {} created", + connection_backend + .connection_info + .get_connection_info() + .addr + ), + ); + Ok(ReconnectingConnection { + inner: Arc::new(InnerReconnectingConnection { + state: Mutex::new(ConnectionState::Connected(connection)), + backend: connection_backend, + }), + }) + } + Err(err) => { + log_warn( + "connection creation", + format!( + "Failed connecting to {}, due to {err}", + connection_backend + .connection_info + .get_connection_info() + .addr + ), + ); + Err(( + ReconnectingConnection { + inner: Arc::new(InnerReconnectingConnection { + state: Mutex::new(ConnectionState::Reconnecting), + backend: connection_backend, + }), + }, + err, + )) + } + } } fn get_client( address: &AddressInfo, tls_mode: TlsMode, redis_connection_info: redis::RedisConnectionInfo, -) -> RedisResult { - redis::Client::open(get_connection_info( +) -> redis::Client { + redis::Client::open(super::get_connection_info( address, tls_mode, redis_connection_info, )) + .unwrap() // can unwrap, because [open] fails only on trying to convert input to ConnectionInfo, and we pass ConnectionInfo. } /// This iterator isn't exposed to users, and can't be configured. @@ -98,23 +130,19 @@ impl ReconnectingConnection { connection_retry_strategy: RetryStrategy, redis_connection_info: RedisConnectionInfo, tls_mode: TlsMode, - ) -> RedisResult { + ) -> Result { log_debug( "connection creation", format!("Attempting connection to {address}"), ); - let client = ConnectionBackend { - connection_info: get_client(address, tls_mode, redis_connection_info)?, + let connection_info = get_client(address, tls_mode, redis_connection_info); + let backend = ConnectionBackend { + connection_info, connection_available_signal: ManualResetEvent::new(true), client_dropped_flagged: AtomicBool::new(false), }; - let connection = try_create_connection(client, connection_retry_strategy).await?; - log_debug( - "connection creation", - format!("Connection to {address} created"), - ); - Ok(connection) + create_connection(backend, connection_retry_strategy).await } pub(super) fn is_dropped(&self) -> bool { @@ -149,8 +177,8 @@ impl ReconnectingConnection { } } - pub(super) fn reconnect(&self) { - { + pub(super) fn reconnect(&self, force_reconnect: bool) { + if !force_reconnect { let mut guard = self.inner.state.lock().unwrap(); match &*guard { ConnectionState::Connected(_) => { @@ -179,7 +207,6 @@ impl ReconnectingConnection { // Client was dropped, reconnection attempts can stop return; } - log_debug("connection creation", "Creating multiplexed connection"); match get_multiplexed_connection(client).await { Ok(connection) => { { @@ -199,4 +226,11 @@ impl ReconnectingConnection { } }); } + + pub fn is_connected(&self) -> bool { + !matches!( + *self.inner.state.lock().unwrap(), + ConnectionState::Reconnecting + ) + } } diff --git a/babushka-core/tests/test_client_cmd.rs b/babushka-core/tests/test_client_cmd.rs index d118417efd..6ed6ffb173 100644 --- a/babushka-core/tests/test_client_cmd.rs +++ b/babushka-core/tests/test_client_cmd.rs @@ -2,7 +2,10 @@ mod utilities; #[cfg(test)] mod client_cmd_tests { + use crate::utilities::mocks::{Mock, ServerMock}; + use super::*; + use babushka::{client::ClientCMD, connection_request::ReadFromReplicaStrategy}; use redis::Value; use rstest::rstest; use std::time::Duration; @@ -63,7 +66,6 @@ mod client_cmd_tests { let mut test_basics = setup_test_basics(use_tls).await; let server = test_basics.server; let address = server.get_client_addr(); - println!("dropping server"); drop(server); // we use another thread, so that the creation of the server won't block the client work. @@ -91,4 +93,152 @@ mod client_cmd_tests { assert_eq!(get_result, Value::Nil); }); } + + fn create_primary_mock_with_replicas(replica_count: usize) -> Vec { + let mut listeners: Vec = (0..replica_count + 1) + .map(|_| get_listener_on_available_port()) + .collect(); + let mut primary_responses = std::collections::HashMap::new(); + primary_responses.insert( + "*2\r\n$4\r\nINFO\r\n$11\r\nREPLICATION\r\n".to_string(), + Value::Data(b"role:master\r\nconnected_slaves:3\r\n".to_vec()), + ); + let primary = ServerMock::new_with_listener(primary_responses, listeners.pop().unwrap()); + let mut mocks = vec![primary]; + let mut replica_responses = std::collections::HashMap::new(); + replica_responses.insert( + "*2\r\n$4\r\nINFO\r\n$11\r\nREPLICATION\r\n".to_string(), + Value::Data(b"role:slave\r\n".to_vec()), + ); + mocks.extend( + listeners + .into_iter() + .map(|listener| ServerMock::new_with_listener(replica_responses.clone(), listener)), + ); + mocks + } + + struct ReadFromReplicaTestConfig { + read_from_replica_strategy: ReadFromReplicaStrategy, + expected_primary_reads: u16, + expected_replica_reads: Vec, + number_of_missing_replicas: usize, + number_of_replicas_dropped_after_connection: usize, + number_of_requests_sent: usize, + } + + impl Default for ReadFromReplicaTestConfig { + fn default() -> Self { + Self { + read_from_replica_strategy: ReadFromReplicaStrategy::AlwaysFromPrimary, + expected_primary_reads: 3, + expected_replica_reads: vec![0, 0, 0], + number_of_missing_replicas: 0, + number_of_replicas_dropped_after_connection: 0, + number_of_requests_sent: 3, + } + } + } + + fn test_read_from_replica(config: ReadFromReplicaTestConfig) { + let mut mocks = create_primary_mock_with_replicas(3 - config.number_of_missing_replicas); + let mut cmd = redis::cmd("GET"); + cmd.arg("foo"); + + for mock in mocks.iter() { + for _ in 0..3 { + mock.add_response(&cmd, "$-1\r\n".to_string()); + } + } + + let mut addresses: Vec = + mocks.iter().flat_map(|mock| mock.get_addresses()).collect(); + + for i in 4 + - config.number_of_missing_replicas + - config.number_of_replicas_dropped_after_connection..4 + { + addresses.push(redis::ConnectionAddr::Tcp( + "foo".to_string(), + 6379 + i as u16, + )); + } + let mut connection_request = + create_connection_request(addresses.as_slice(), &Default::default()); + connection_request.read_from_replica_strategy = config.read_from_replica_strategy.into(); + + block_on_all(async { + let mut client = ClientCMD::create_client(connection_request).await.unwrap(); + mocks.drain(1..config.number_of_replicas_dropped_after_connection + 1); + for _ in 0..config.number_of_requests_sent { + let _ = client.send_packed_command(&cmd).await; + } + }); + + assert_eq!( + mocks[0].get_number_of_received_commands(), + config.expected_primary_reads + ); + let mut replica_reads: Vec<_> = mocks + .iter() + .skip(1) + .map(|mock| mock.get_number_of_received_commands()) + .collect(); + replica_reads.sort(); + assert_eq!(config.expected_replica_reads, replica_reads); + } + + #[rstest] + #[timeout(SHORT_CMD_TEST_TIMEOUT)] + fn test_read_from_replica_always_read_from_primary() { + test_read_from_replica(ReadFromReplicaTestConfig::default()); + } + + #[rstest] + #[timeout(SHORT_CMD_TEST_TIMEOUT)] + fn test_read_from_replica_round_robin() { + test_read_from_replica(ReadFromReplicaTestConfig { + read_from_replica_strategy: ReadFromReplicaStrategy::RoundRobin, + expected_primary_reads: 0, + expected_replica_reads: vec![1, 1, 1], + ..Default::default() + }); + } + + #[rstest] + #[timeout(SHORT_CMD_TEST_TIMEOUT)] + fn test_read_from_replica_round_robin_skip_disconnected_replicas() { + test_read_from_replica(ReadFromReplicaTestConfig { + read_from_replica_strategy: ReadFromReplicaStrategy::RoundRobin, + expected_primary_reads: 0, + expected_replica_reads: vec![1, 2], + number_of_missing_replicas: 1, + ..Default::default() + }); + } + + #[rstest] + #[timeout(SHORT_CMD_TEST_TIMEOUT)] + fn test_read_from_replica_round_robin_read_from_primary_if_no_replica_is_connected() { + test_read_from_replica(ReadFromReplicaTestConfig { + read_from_replica_strategy: ReadFromReplicaStrategy::RoundRobin, + expected_primary_reads: 3, + expected_replica_reads: vec![], + number_of_missing_replicas: 3, + ..Default::default() + }); + } + + #[rstest] + #[timeout(SHORT_CMD_TEST_TIMEOUT)] + fn test_read_from_replica_round_robin_do_not_read_from_disconnected_replica() { + test_read_from_replica(ReadFromReplicaTestConfig { + read_from_replica_strategy: ReadFromReplicaStrategy::RoundRobin, + expected_primary_reads: 0, + expected_replica_reads: vec![2, 3], + number_of_replicas_dropped_after_connection: 1, + number_of_requests_sent: 6, + ..Default::default() + }); + } } diff --git a/babushka-core/tests/test_socket_listener.rs b/babushka-core/tests/test_socket_listener.rs index ccd587eca7..96c7ae1af9 100644 --- a/babushka-core/tests/test_socket_listener.rs +++ b/babushka-core/tests/test_socket_listener.rs @@ -24,7 +24,6 @@ mod socket_listener { use redis::{Cmd, ConnectionAddr, Value}; use redis_request::{RedisRequest, RequestType}; use rstest::rstest; - use std::collections::HashMap; use std::mem::size_of; use tokio::{net::UnixListener, runtime::Builder}; @@ -325,7 +324,12 @@ mod socket_listener { } fn setup_mocked_test_basics(socket_path: Option) -> TestBasicsWithMock { - let server_mock = ServerMock::new(HashMap::new()); + let mut responses = std::collections::HashMap::new(); + responses.insert( + "*2\r\n$4\r\nINFO\r\n$11\r\nREPLICATION\r\n".to_string(), + Value::Data(b"role:master\r\nconnected_slaves:0\r\n".to_vec()), + ); + let server_mock = ServerMock::new(responses); let addresses = server_mock.get_addresses(); let socket = setup_socket( false, @@ -625,7 +629,6 @@ mod socket_listener { write_get(&mut buffer, CALLBACK_INDEX, key.as_str(), use_arg_pointer); test_basics.socket.write_all(&buffer).unwrap(); - println!("test reading from socket"); let _size = read_from_socket(&mut buffer, &mut test_basics.socket); assert_null_response(&buffer, CALLBACK_INDEX); } diff --git a/babushka-core/tests/utilities/mocks.rs b/babushka-core/tests/utilities/mocks.rs index 4414aedf92..7641517080 100644 --- a/babushka-core/tests/utilities/mocks.rs +++ b/babushka-core/tests/utilities/mocks.rs @@ -2,6 +2,7 @@ use futures_intrusive::sync::ManualResetEvent; use redis::{Cmd, ConnectionAddr, Value}; use std::collections::HashMap; use std::io; +use std::net::TcpListener; use std::str::from_utf8; use std::sync::{ atomic::{AtomicU16, Ordering}, @@ -95,6 +96,13 @@ pub trait Mock { impl ServerMock { pub fn new(constant_responses: HashMap) -> Self { let listener = super::get_listener_on_available_port(); + Self::new_with_listener(constant_responses, listener) + } + + pub fn new_with_listener( + constant_responses: HashMap, + listener: TcpListener, + ) -> Self { let (request_sender, mut receiver) = tokio::sync::mpsc::unbounded_channel(); let received_commands = Arc::new(AtomicU16::new(0)); let received_commands_clone = received_commands.clone();