Skip to content

Commit

Permalink
Add CMD read from replica.
Browse files Browse the repository at this point in the history
  • Loading branch information
shachlanAmazon committed Aug 2, 2023
1 parent 8d711c6 commit 664da0e
Show file tree
Hide file tree
Showing 5 changed files with 437 additions and 107 deletions.
281 changes: 208 additions & 73 deletions babushka-core/src/client/client_cmd.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
use super::get_redis_connection_info;
use super::reconnecting_connection::ReconnectingConnection;
use crate::connection_request::{ConnectionRequest, TlsMode};
use crate::connection_request::{AddressInfo, ConnectionRequest, TlsMode};
use crate::retry_strategies::RetryStrategy;
use futures::{stream, StreamExt};
use logger_core::{log_debug, log_trace};
use redis::RedisError;
use logger_core::{log_debug, log_trace, log_warn};
use protobuf::EnumOrUnknown;
use redis::{RedisError, RedisResult, Value};
use std::sync::Arc;
use tokio::task;

/// Strategy that decides which node serves read-only commands.
enum ReadFromReplicaStrategy {
/// Read-only commands are sent to the primary node, like every other command.
AlwaysFromPrimary,
/// Read-only commands are rotated between the non-primary nodes.
RoundRobin {
/// Shared atomic index consulted by the round-robin selection to decide which node to try next.
latest_read_replica_index: Arc<std::sync::atomic::AtomicUsize>,
},
}

struct DropWrapper {
/// Connection to the primary node in the client.
primary: ReconnectingConnection,
replicas: Vec<ReconnectingConnection>,
primary_index: usize,
nodes: Vec<ReconnectingConnection>,
read_from_replica_strategy: ReadFromReplicaStrategy,
}

impl Drop for DropWrapper {
fn drop(&mut self) {
self.primary.mark_as_dropped();
for node in self.nodes.iter() {
node.mark_as_dropped();
}
}
}

Expand All @@ -27,7 +38,6 @@ pub struct ClientCMD {

pub enum ClientCMDConnectionError {
NoAddressesProvided,
NoPrimaryFound,
FailedConnection(Vec<(String, RedisError)>),
}

Expand All @@ -37,9 +47,6 @@ impl std::fmt::Debug for ClientCMDConnectionError {
ClientCMDConnectionError::NoAddressesProvided => {
write!(f, "No addresses provided")
}
ClientCMDConnectionError::NoPrimaryFound => {
write!(f, "No primary node found")
}
ClientCMDConnectionError::FailedConnection(errs) => {
writeln!(f, "Received errors:")?;
for (address, error) in errs {
Expand All @@ -63,82 +70,119 @@ impl ClientCMD {
get_redis_connection_info(connection_request.authentication_info.0);

let tls_mode = connection_request.tls_mode.enum_value_or(TlsMode::NoTls);
let connections = stream::iter(connection_request.addresses.iter())
let node_count = connection_request.addresses.len();
let mut stream = stream::iter(connection_request.addresses.iter())
.map(|address| async {
(
format!("{}:{}", address.host, address.port),
async {
let reconnecting_connection = ReconnectingConnection::new(
address,
retry_strategy.clone(),
redis_connection_info.clone(),
tls_mode,
)
.await?;
let mut multiplexed_connection =
reconnecting_connection.get_connection().await?;
let replication_status = multiplexed_connection
.send_packed_command(redis::cmd("INFO").arg("REPLICATION"))
.await?;
Ok((reconnecting_connection, replication_status))
}
.await,
get_connection_and_replication_info(
address,
&retry_strategy,
&redis_connection_info,
tls_mode,
)
.await
.map_err(|err| (format!("{}:{}", address.host, address.port), err))
})
.buffer_unordered(connection_request.addresses.len())
.collect::<Vec<_>>()
.await;

if connections.iter().any(|(_, result)| result.is_err()) {
let addresses_and_errors: Vec<(String, RedisError)> = connections
.into_iter()
.filter_map(|(address, result)| result.err().map(|err| (address, err)))
.collect();
return Err(ClientCMDConnectionError::FailedConnection(
addresses_and_errors,
));
.buffer_unordered(node_count);

let mut nodes = Vec::with_capacity(node_count);
let mut addresses_and_errors = Vec::with_capacity(node_count);
let mut primary_index = None;
while let Some(result) = stream.next().await {
match result {
Ok((connection, replication_status)) => {
nodes.push(connection);
if primary_index.is_none()
&& redis::from_redis_value::<String>(&replication_status)
.is_ok_and(|val| val.contains("role:master"))
{
primary_index = Some(nodes.len() - 1);
}
}
Err((address, (connection, err))) => {
nodes.push(connection);
addresses_and_errors.push((address, err));
}
}
}
let results: Vec<(ReconnectingConnection, redis::Value)> = connections
.into_iter()
.map(|(_, result)| result.unwrap())
.collect();
let Some(primary_index) = results.iter().position(|(_, replication_status)|
redis::from_redis_value::<String>(replication_status).ok().and_then(|val|if val.contains("role:master") { Some(())} else {None}).is_some()
) else {
return Err(ClientCMDConnectionError::NoPrimaryFound);

let Some(primary_index) = primary_index else {
return Err(ClientCMDConnectionError::FailedConnection(addresses_and_errors));
};
let mut connections: Vec<ReconnectingConnection> = results
.into_iter()
.map(|(connection, _)| connection)
.collect();
let Some(primary) = connections
.drain(primary_index..primary_index+1)
.next() else {
return Err(ClientCMDConnectionError::NoPrimaryFound);
};
if !addresses_and_errors.is_empty() {
log_warn(
"client creation",
format!(
"Failed to connect to {addresses_and_errors:?}, will attempt to reconnect."
),
);
}
let read_from_replica_strategy =
get_read_from_replica_strategy(&connection_request.read_from_replica_strategy);

Self::start_heartbeat(primary.clone());
for connection in connections.iter() {
Self::start_heartbeat(connection.clone());
for node in nodes.iter() {
Self::start_heartbeat(node.clone());
}

Ok(Self {
inner: Arc::new(DropWrapper {
primary,
replicas: connections,
primary_index,
nodes,
read_from_replica_strategy,
}),
})
}

pub async fn send_packed_command(
&mut self,
cmd: &redis::Cmd,
) -> redis::RedisResult<redis::Value> {
/// Returns the connection to the node that was identified as the primary.
///
/// # Panics
///
/// Panics if `primary_index` is not a valid index into `nodes`. That would be a broken
/// internal invariant: the index is recorded as the position of a node that was just
/// pushed into `nodes` when the client is created.
fn get_primary_connection(&self) -> &ReconnectingConnection {
    self.inner
        .nodes
        .get(self.inner.primary_index)
        .expect("primary_index must always point at an existing node")
}

/// Picks the node that should serve `cmd`.
///
/// Commands that are not read-only, and clients that hold a single node, always use the
/// primary. Otherwise the configured strategy decides:
/// * `AlwaysFromPrimary` returns the primary.
/// * `RoundRobin` returns the next connected non-primary node after the one that served
///   the previous read, and falls back to the primary when no such node is connected.
fn get_connection(&self, cmd: &redis::Cmd) -> &ReconnectingConnection {
    use std::sync::atomic::Ordering;

    let node_count = self.inner.nodes.len();
    if !is_readonly_cmd(cmd) || node_count == 1 {
        return self.get_primary_connection();
    }
    match &self.inner.read_from_replica_strategy {
        ReadFromReplicaStrategy::AlwaysFromPrimary => self.get_primary_connection(),
        ReadFromReplicaStrategy::RoundRobin {
            latest_read_replica_index,
        } => {
            // Reduce modulo the node count so the additions below cannot overflow.
            let last_used = latest_read_replica_index.load(Ordering::Relaxed) % node_count;

            // Probe every node exactly once, starting right after the node that served the
            // previous read and finishing on that node itself. Ending on the last used node
            // matters: when it is the only connected replica it must stay eligible,
            // otherwise all reads after the first one would go to the primary.
            // The bounded loop also guarantees termination when several tasks select
            // a node concurrently.
            for offset in 1..=node_count {
                let index = (last_used + offset) % node_count;
                if index == self.inner.primary_index {
                    continue;
                }
                let connection = &self.inner.nodes[index];
                if connection.is_connected() {
                    // Only remember nodes that actually served a read, so that probing
                    // disconnected nodes does not skew the rotation.
                    latest_read_replica_index.store(index, Ordering::Relaxed);
                    return connection;
                }
            }

            // Looped through all replicas, no connected replica was found.
            self.get_primary_connection()
        }
    }
}

pub async fn send_packed_command(&mut self, cmd: &redis::Cmd) -> RedisResult<Value> {
log_trace("ClientCMD", "sending command");
let mut connection = self.inner.primary.get_connection().await?;
let reconnecting_connection = self.get_connection(cmd);
let mut connection = reconnecting_connection.get_connection().await?;
let result = connection.send_packed_command(cmd).await;
match result {
Err(err) if err.is_connection_dropped() => {
self.inner.primary.reconnect();
reconnecting_connection.reconnect(false);
Err(err)
}
_ => result,
Expand All @@ -150,12 +194,13 @@ impl ClientCMD {
cmd: &redis::Pipeline,
offset: usize,
count: usize,
) -> redis::RedisResult<Vec<redis::Value>> {
let mut connection = self.inner.primary.get_connection().await?;
) -> RedisResult<Vec<Value>> {
let reconnecting_connection = self.get_primary_connection();
let mut connection = reconnecting_connection.get_connection().await?;
let result = connection.send_packed_commands(cmd, offset, count).await;
match result {
Err(err) if err.is_connection_dropped() => {
self.inner.primary.reconnect();
reconnecting_connection.reconnect(false);
Err(err)
}
_ => result,
Expand Down Expand Up @@ -190,9 +235,99 @@ impl ClientCMD {
.is_err_and(|err| err.is_connection_dropped() || err.is_connection_refusal())
{
log_debug("ClientCMD", "heartbeat triggered reconnect");
reconnecting_connection.reconnect();
reconnecting_connection.reconnect(false);
}
}
});
}
}

/// Opens a reconnecting connection to `address` and queries it with `INFO REPLICATION`.
///
/// On success returns the connection together with the raw reply to the query.
/// On failure returns the connection together with the error, so the caller can keep the
/// node around. When creating the connection or obtaining the underlying multiplexed
/// connection fails, `reconnect(true)` is invoked on the connection before returning.
/// A failure of the `INFO REPLICATION` query itself is returned without that call.
async fn get_connection_and_replication_info(
    address: &AddressInfo,
    retry_strategy: &RetryStrategy,
    connection_info: &redis::RedisConnectionInfo,
    tls_mode: TlsMode,
) -> Result<(ReconnectingConnection, Value), (ReconnectingConnection, RedisError)> {
    let creation_result = ReconnectingConnection::new(
        address,
        retry_strategy.clone(),
        connection_info.clone(),
        tls_mode,
    )
    .await;

    let reconnecting_connection = match creation_result {
        Ok(created) => created,
        Err(failure) => {
            failure.0.reconnect(true);
            return Err(failure);
        }
    };

    let mut multiplexed_connection = match reconnecting_connection.get_connection().await {
        Ok(established) => established,
        Err(err) => {
            reconnecting_connection.reconnect(true);
            return Err((reconnecting_connection, err));
        }
    };

    let query_result = multiplexed_connection
        .send_packed_command(redis::cmd("INFO").arg("REPLICATION"))
        .await;
    match query_result {
        Ok(replication_status) => Ok((reconnecting_connection, replication_status)),
        Err(err) => Err((reconnecting_connection, err)),
    }
}

// Copied and djusted from redis-rs
fn is_readonly_cmd(cmd: &redis::Cmd) -> bool {
matches!(
cmd.args_iter().next(),
// @admin
Some(redis::Arg::Simple(b"LASTSAVE")) |
// @bitmap
Some(redis::Arg::Simple(b"BITCOUNT")) | Some(redis::Arg::Simple(b"BITFIELD_RO")) | Some(redis::Arg::Simple(b"BITPOS")) | Some(redis::Arg::Simple(b"GETBIT")) |
// @connection
Some(redis::Arg::Simple(b"CLIENT")) | Some(redis::Arg::Simple(b"ECHO")) |
// @geo
Some(redis::Arg::Simple(b"GEODIST")) | Some(redis::Arg::Simple(b"GEOHASH")) | Some(redis::Arg::Simple(b"GEOPOS")) | Some(redis::Arg::Simple(b"GEORADIUSBYMEMBER_RO")) | Some(redis::Arg::Simple(b"GEORADIUS_RO")) | Some(redis::Arg::Simple(b"GEOSEARCH")) |
// @hash
Some(redis::Arg::Simple(b"HEXISTS")) | Some(redis::Arg::Simple(b"HGET")) | Some(redis::Arg::Simple(b"HGETALL")) | Some(redis::Arg::Simple(b"HKEYS")) | Some(redis::Arg::Simple(b"HLEN")) | Some(redis::Arg::Simple(b"HMGET")) | Some(redis::Arg::Simple(b"HRANDFIELD")) | Some(redis::Arg::Simple(b"HSCAN")) | Some(redis::Arg::Simple(b"HSTRLEN")) | Some(redis::Arg::Simple(b"HVALS")) |
// @hyperloglog
Some(redis::Arg::Simple(b"PFCOUNT")) |
// @keyspace
Some(redis::Arg::Simple(b"DBSIZE")) | Some(redis::Arg::Simple(b"DUMP")) | Some(redis::Arg::Simple(b"EXISTS")) | Some(redis::Arg::Simple(b"EXPIRETIME")) | Some(redis::Arg::Simple(b"KEYS")) | Some(redis::Arg::Simple(b"OBJECT")) | Some(redis::Arg::Simple(b"PEXPIRETIME")) | Some(redis::Arg::Simple(b"PTTL")) | Some(redis::Arg::Simple(b"RANDOMKEY")) | Some(redis::Arg::Simple(b"SCAN")) | Some(redis::Arg::Simple(b"TOUCH")) | Some(redis::Arg::Simple(b"TTL")) | Some(redis::Arg::Simple(b"TYPE")) |
// @list
Some(redis::Arg::Simple(b"LINDEX")) | Some(redis::Arg::Simple(b"LLEN")) | Some(redis::Arg::Simple(b"LPOS")) | Some(redis::Arg::Simple(b"LRANGE")) | Some(redis::Arg::Simple(b"SORT_RO")) |
// @scripting
Some(redis::Arg::Simple(b"EVALSHA_RO")) | Some(redis::Arg::Simple(b"EVAL_RO")) | Some(redis::Arg::Simple(b"FCALL_RO")) |
// @set
Some(redis::Arg::Simple(b"SCARD")) | Some(redis::Arg::Simple(b"SDIFF")) | Some(redis::Arg::Simple(b"SINTER")) | Some(redis::Arg::Simple(b"SINTERCARD")) | Some(redis::Arg::Simple(b"SISMEMBER")) | Some(redis::Arg::Simple(b"SMEMBERS")) | Some(redis::Arg::Simple(b"SMISMEMBER")) | Some(redis::Arg::Simple(b"SRANDMEMBER")) | Some(redis::Arg::Simple(b"SSCAN")) | Some(redis::Arg::Simple(b"SUNION")) |
// @sortedset
Some(redis::Arg::Simple(b"ZCARD")) | Some(redis::Arg::Simple(b"ZCOUNT")) | Some(redis::Arg::Simple(b"ZDIFF")) | Some(redis::Arg::Simple(b"ZINTER")) | Some(redis::Arg::Simple(b"ZINTERCARD")) | Some(redis::Arg::Simple(b"ZLEXCOUNT")) | Some(redis::Arg::Simple(b"ZMSCORE")) | Some(redis::Arg::Simple(b"ZRANDMEMBER")) | Some(redis::Arg::Simple(b"ZRANGE")) | Some(redis::Arg::Simple(b"ZRANGEBYLEX")) | Some(redis::Arg::Simple(b"ZRANGEBYSCORE")) | Some(redis::Arg::Simple(b"ZRANK")) | Some(redis::Arg::Simple(b"ZREVRANGE")) | Some(redis::Arg::Simple(b"ZREVRANGEBYLEX")) | Some(redis::Arg::Simple(b"ZREVRANGEBYSCORE")) | Some(redis::Arg::Simple(b"ZREVRANK")) | Some(redis::Arg::Simple(b"ZSCAN")) | Some(redis::Arg::Simple(b"ZSCORE")) | Some(redis::Arg::Simple(b"ZUNION")) |
// @stream
Some(redis::Arg::Simple(b"XINFO")) | Some(redis::Arg::Simple(b"XLEN")) | Some(redis::Arg::Simple(b"XPENDING")) | Some(redis::Arg::Simple(b"XRANGE")) | Some(redis::Arg::Simple(b"XREAD")) | Some(redis::Arg::Simple(b"XREVRANGE")) |
// @string
Some(redis::Arg::Simple(b"GET")) | Some(redis::Arg::Simple(b"GETRANGE")) | Some(redis::Arg::Simple(b"LCS")) | Some(redis::Arg::Simple(b"MGET")) | Some(redis::Arg::Simple(b"STRALGO")) | Some(redis::Arg::Simple(b"STRLEN")) | Some(redis::Arg::Simple(b"SUBSTR"))
)
}

fn get_read_from_replica_strategy(
read_from_replica_strategy: &EnumOrUnknown<crate::connection_request::ReadFromReplicaStrategy>,
) -> ReadFromReplicaStrategy {
match read_from_replica_strategy
.enum_value_or(crate::connection_request::ReadFromReplicaStrategy::AlwaysFromPrimary)
{
crate::connection_request::ReadFromReplicaStrategy::AlwaysFromPrimary => {
ReadFromReplicaStrategy::AlwaysFromPrimary
}
crate::connection_request::ReadFromReplicaStrategy::RoundRobin => {
ReadFromReplicaStrategy::RoundRobin {
latest_read_replica_index: Default::default(),
}
}
crate::connection_request::ReadFromReplicaStrategy::LowestLatency => todo!(),
crate::connection_request::ReadFromReplicaStrategy::AZAffinity => todo!(),
}
}
Loading

0 comments on commit 664da0e

Please sign in to comment.