Skip to content

Commit

Permalink
Pubsub implementation in glide-core and with Python wrapper. Works bo…
Browse files Browse the repository at this point in the history
…th in standalone and cluster modes. Pubsub configuration is provided in client creation params.
  • Loading branch information
ikolomi committed Jun 18, 2024
1 parent a9566a3 commit 18e6a9b
Show file tree
Hide file tree
Showing 29 changed files with 525 additions and 71 deletions.
2 changes: 1 addition & 1 deletion benchmarks/rust/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async fn get_connection(args: &Args) -> Client {
..Default::default()
};

glide_core::client::Client::new(connection_request)
glide_core::client::Client::new(connection_request, None)
.await
.unwrap()
}
Expand Down
2 changes: 1 addition & 1 deletion csharp/lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fn create_client_internal(
.thread_name("GLIDE for Redis C# thread")
.build()?;
let _runtime_handle = runtime.enter();
let client = runtime.block_on(GlideClient::new(request)).unwrap(); // TODO - handle errors.
let client = runtime.block_on(GlideClient::new(request, None)).unwrap(); // TODO - handle errors.
Ok(Client {
client,
success_callback,
Expand Down
4 changes: 2 additions & 2 deletions glide-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ tokio-retry = "0.3.0"
protobuf = { version= "3", features = ["bytes", "with-bytes"], optional = true }
integer-encoding = { version = "4.0.0", optional = true }
thiserror = "1"
rand = { version = "0.8.5", optional = true }
rand = { version = "0.8.5" }
futures-intrusive = "0.5.0"
directories = { version = "4.0", optional = true }
once_cell = "1.18.0"
arcstr = "1.1.5"
sha1_smol = "1.0.0"

[features]
socket-layer = ["directories", "integer-encoding", "num_cpus", "protobuf", "tokio-util", "bytes", "rand"]
socket-layer = ["directories", "integer-encoding", "num_cpus", "protobuf", "tokio-util", "bytes"]

[dev-dependencies]
rsevents = "0.3.1"
Expand Down
4 changes: 2 additions & 2 deletions glide-core/benches/connections_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ fn get_connection_info(address: ConnectionAddr) -> redis::ConnectionInfo {
fn multiplexer_benchmark(c: &mut Criterion, address: ConnectionAddr, group: &str) {
benchmark(c, address, "multiplexer", group, |address, runtime| {
let client = redis::Client::open(get_connection_info(address)).unwrap();
runtime.block_on(async { client.get_multiplexed_tokio_connection().await.unwrap() })
runtime.block_on(async { client.get_multiplexed_tokio_connection(None).await.unwrap() })
});
}

Expand Down Expand Up @@ -120,7 +120,7 @@ fn cluster_connection_benchmark(
builder = builder.read_from_replicas();
}
let client = builder.build().unwrap();
client.get_async_connection().await
client.get_async_connection(None).await
})
.unwrap()
});
Expand Down
2 changes: 1 addition & 1 deletion glide-core/benches/memory_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ where
{
let runtime = Builder::new_current_thread().enable_all().build().unwrap();
runtime.block_on(async {
let client = Client::new(create_connection_request().into())
let client = Client::new(create_connection_request().into(), None)
.await
.unwrap();
f(client).await;
Expand Down
29 changes: 23 additions & 6 deletions glide-core/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use logger_core::log_info;
use redis::aio::ConnectionLike;
use redis::cluster_async::ClusterConnection;
use redis::cluster_routing::{Routable, RoutingInfo, SingleNodeRoutingInfo};
use redis::{Cmd, ErrorKind, Value};
use redis::{Cmd, ErrorKind, PushInfo, Value};
use redis::{RedisError, RedisResult};
pub use standalone_client::StandaloneClient;
use std::io;
Expand All @@ -21,6 +21,7 @@ use self::value_conversion::{convert_to_expected_type, expected_type_for_cmd, ge
mod reconnecting_connection;
mod standalone_client;
mod value_conversion;
use tokio::sync::mpsc;

pub const HEARTBEAT_SLEEP_DURATION: Duration = Duration::from_secs(1);

Expand All @@ -44,18 +45,21 @@ pub(super) fn get_redis_connection_info(
let protocol = connection_request.protocol.unwrap_or_default();
let db = connection_request.database_id;
let client_name = connection_request.client_name.clone();
let pubsub_subscriptions = connection_request.pubsub_subscriptions.clone();
match &connection_request.authentication_info {
Some(info) => redis::RedisConnectionInfo {
db,
username: info.username.clone(),
password: info.password.clone(),
protocol,
client_name,
pubsub_subscriptions,
},
None => redis::RedisConnectionInfo {
db,
protocol,
client_name,
pubsub_subscriptions,
..Default::default()
},
}
Expand Down Expand Up @@ -373,6 +377,7 @@ fn to_duration(time_in_millis: Option<u32>, default: Duration) -> Duration {

async fn create_cluster_client(
request: ConnectionRequest,
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
) -> RedisResult<redis::cluster_async::ClusterConnection> {
// TODO - implement timeout for each connection attempt
let tls_mode = request.tls_mode.unwrap_or_default();
Expand Down Expand Up @@ -410,8 +415,11 @@ async fn create_cluster_client(
};
builder = builder.tls(tls);
}
if let Some(pubsub_subscriptions) = redis_connection_info.pubsub_subscriptions {
builder = builder.pubsub_subscriptions(pubsub_subscriptions);
}
let client = builder.build()?;
client.get_async_connection().await
client.get_async_connection(push_sender).await
}

#[derive(thiserror::Error)]
Expand Down Expand Up @@ -520,13 +528,22 @@ fn sanitized_request_string(request: &ConnectionRequest) -> String {
String::new()
};

let pubsub_subscriptions = request
.pubsub_subscriptions
.as_ref()
.map(|pubsub_subscriptions| format!("\nPubsub subscriptions: {pubsub_subscriptions:?}"))
.unwrap_or_default();

format!(
"\nAddresses: {addresses}{tls_mode}{cluster_mode}{request_timeout}{rfr_strategy}{connection_retry_strategy}{database_id}{protocol}{client_name}{periodic_checks}",
"\nAddresses: {addresses}{tls_mode}{cluster_mode}{request_timeout}{rfr_strategy}{connection_retry_strategy}{database_id}{protocol}{client_name}{periodic_checks}{pubsub_subscriptions}",
)
}

impl Client {
pub async fn new(request: ConnectionRequest) -> Result<Self, ConnectionError> {
pub async fn new(
request: ConnectionRequest,
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
) -> Result<Self, ConnectionError> {
const DEFAULT_CLIENT_CREATION_TIMEOUT: Duration = Duration::from_secs(10);

log_info(
Expand All @@ -536,13 +553,13 @@ impl Client {
let request_timeout = to_duration(request.request_timeout, DEFAULT_RESPONSE_TIMEOUT);
tokio::time::timeout(DEFAULT_CLIENT_CREATION_TIMEOUT, async move {
let internal_client = if request.cluster_mode_enabled {
let client = create_cluster_client(request)
let client = create_cluster_client(request, push_sender)
.await
.map_err(ConnectionError::Cluster)?;
ClientWrapper::Cluster { client }
} else {
ClientWrapper::Standalone(
StandaloneClient::create_client(request)
StandaloneClient::create_client(request, push_sender)
.await
.map_err(ConnectionError::Standalone)?,
)
Expand Down
22 changes: 16 additions & 6 deletions glide-core/src/client/reconnecting_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use crate::retry_strategies::RetryStrategy;
use futures_intrusive::sync::ManualResetEvent;
use logger_core::{log_debug, log_trace, log_warn};
use redis::aio::MultiplexedConnection;
use redis::{RedisConnectionInfo, RedisError, RedisResult};
use redis::{PushInfo, RedisConnectionInfo, RedisError, RedisResult};
use std::fmt;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::task;
use tokio_retry::Retry;

Expand Down Expand Up @@ -45,6 +46,7 @@ struct InnerReconnectingConnection {
#[derive(Clone)]
pub(super) struct ReconnectingConnection {
inner: Arc<InnerReconnectingConnection>,
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
}

impl fmt::Debug for ReconnectingConnection {
Expand All @@ -53,20 +55,24 @@ impl fmt::Debug for ReconnectingConnection {
}
}

async fn get_multiplexed_connection(client: &redis::Client) -> RedisResult<MultiplexedConnection> {
async fn get_multiplexed_connection(
client: &redis::Client,
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
) -> RedisResult<MultiplexedConnection> {
run_with_timeout(
Some(DEFAULT_CONNECTION_ATTEMPT_TIMEOUT),
client.get_multiplexed_async_connection(),
client.get_multiplexed_async_connection(push_sender),
)
.await
}

async fn create_connection(
connection_backend: ConnectionBackend,
retry_strategy: RetryStrategy,
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
) -> Result<ReconnectingConnection, (ReconnectingConnection, RedisError)> {
let client = &connection_backend.connection_info;
let action = || get_multiplexed_connection(client);
let action = || get_multiplexed_connection(client, push_sender.clone());

match Retry::spawn(retry_strategy.get_iterator(), action).await {
Ok(connection) => {
Expand All @@ -85,6 +91,7 @@ async fn create_connection(
state: Mutex::new(ConnectionState::Connected(connection)),
backend: connection_backend,
}),
push_sender,
})
}
Err(err) => {
Expand All @@ -103,6 +110,7 @@ async fn create_connection(
state: Mutex::new(ConnectionState::InitializedDisconnected),
backend: connection_backend,
}),
push_sender,
};
connection.reconnect();
Err((connection, err))
Expand Down Expand Up @@ -141,6 +149,7 @@ impl ReconnectingConnection {
connection_retry_strategy: RetryStrategy,
redis_connection_info: RedisConnectionInfo,
tls_mode: TlsMode,
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
) -> Result<ReconnectingConnection, (ReconnectingConnection, RedisError)> {
log_debug(
"connection creation",
Expand All @@ -153,7 +162,7 @@ impl ReconnectingConnection {
connection_available_signal: ManualResetEvent::new(true),
client_dropped_flagged: AtomicBool::new(false),
};
create_connection(backend, connection_retry_strategy).await
create_connection(backend, connection_retry_strategy, push_sender).await
}

fn node_address(&self) -> String {
Expand Down Expand Up @@ -211,6 +220,7 @@ impl ReconnectingConnection {
log_debug("reconnect", "starting");

let connection_clone = self.clone();
let push_sender = self.push_sender.clone();
// The reconnect task is spawned instead of awaited here, so that the reconnect attempt will continue in the
// background, regardless of whether the calling task is dropped or not.
task::spawn(async move {
Expand All @@ -224,7 +234,7 @@ impl ReconnectingConnection {
// Client was dropped, reconnection attempts can stop
return;
}
match get_multiplexed_connection(client).await {
match get_multiplexed_connection(client, push_sender.clone()).await {
Ok(mut connection) => {
if connection
.send_packed_command(&redis::cmd("PING"))
Expand Down
21 changes: 18 additions & 3 deletions glide-core/src/client/standalone_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ use futures::{future, stream, StreamExt};
#[cfg(standalone_heartbeat)]
use logger_core::log_debug;
use logger_core::log_warn;
use rand::Rng;
use redis::cluster_routing::{self, is_readonly_cmd, ResponsePolicy, Routable, RoutingInfo};
use redis::{RedisError, RedisResult, Value};
use redis::{PushInfo, RedisError, RedisResult, Value};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use tokio::sync::mpsc;
#[cfg(standalone_heartbeat)]
use tokio::task;

Expand Down Expand Up @@ -96,22 +98,33 @@ impl std::fmt::Debug for StandaloneClientConnectionError {
impl StandaloneClient {
pub async fn create_client(
connection_request: ConnectionRequest,
push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
) -> Result<Self, StandaloneClientConnectionError> {
if connection_request.addresses.is_empty() {
return Err(StandaloneClientConnectionError::NoAddressesProvided);
}
let redis_connection_info = get_redis_connection_info(&connection_request);
let mut redis_connection_info = get_redis_connection_info(&connection_request);
let pubsub_connection_info = redis_connection_info.clone();
redis_connection_info.pubsub_subscriptions = None;
let retry_strategy = RetryStrategy::new(connection_request.connection_retry_strategy);

let tls_mode = connection_request.tls_mode;
let node_count = connection_request.addresses.len();
// randomize pubsub nodes, maybe a batter option is to always use the primary
let pubsub_node_index = rand::thread_rng().gen_range(0..node_count);
let pubsub_addr = &connection_request.addresses[pubsub_node_index];
let mut stream = stream::iter(connection_request.addresses.iter())
.map(|address| async {
get_connection_and_replication_info(
address,
&retry_strategy,
&redis_connection_info,
if address.to_string() != pubsub_addr.to_string() {
&redis_connection_info
} else {
&pubsub_connection_info
},
tls_mode.unwrap_or(TlsMode::NoTls),
&push_sender,
)
.await
.map_err(|err| (format!("{}:{}", address.host, address.port), err))
Expand Down Expand Up @@ -392,12 +405,14 @@ async fn get_connection_and_replication_info(
retry_strategy: &RetryStrategy,
connection_info: &redis::RedisConnectionInfo,
tls_mode: TlsMode,
push_sender: &Option<mpsc::UnboundedSender<PushInfo>>,
) -> Result<(ReconnectingConnection, Value), (ReconnectingConnection, RedisError)> {
let result = ReconnectingConnection::new(
address,
retry_strategy.clone(),
connection_info.clone(),
tls_mode,
push_sender.clone(),
)
.await;
let reconnecting_connection = match result {
Expand Down
37 changes: 37 additions & 0 deletions glide-core/src/client/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
* Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0
*/

use logger_core::log_warn;
use std::collections::HashSet;
use std::time::Duration;

#[cfg(feature = "socket-layer")]
Expand All @@ -20,6 +22,7 @@ pub struct ConnectionRequest {
pub request_timeout: Option<u32>,
pub connection_retry_strategy: Option<ConnectionRetryStrategy>,
pub periodic_checks: Option<PeriodicCheck>,
pub pubsub_subscriptions: Option<redis::PubSubSubscriptionInfo>,
}

pub struct AuthenticationInfo {
Expand Down Expand Up @@ -150,6 +153,39 @@ impl From<protobuf::ConnectionRequest> for ConnectionRequest {
PeriodicCheck::Disabled
}
});
let mut pubsub_subscriptions: Option<redis::PubSubSubscriptionInfo> = None;
if let Some(protobuf_pubsub) = value.pubsub_subscriptions.0 {
let mut redis_pubsub = redis::PubSubSubscriptionInfo::new();
for (pubsub_type, channels_patterns) in
protobuf_pubsub.channels_or_patterns_by_type.iter()
{
let kind = match *pubsub_type {
0 => redis::PubSubSubscriptionKind::Exact,
1 => redis::PubSubSubscriptionKind::Pattern,
2 => redis::PubSubSubscriptionKind::Sharded,
3_u32..=u32::MAX => {
log_warn(
"client creation",
format!(
"Omitting pubsub subscription on an unknown type: {:?}",
*pubsub_type
),
);
continue;
}
};

for channel_pattern in channels_patterns.channels_or_patterns.iter() {
redis_pubsub
.entry(kind)
.and_modify(|channels_patterns| {
channels_patterns.insert(channel_pattern.to_vec());
})
.or_insert(HashSet::from([channel_pattern.to_vec()]));
}
}
pubsub_subscriptions = Some(redis_pubsub);
}

ConnectionRequest {
read_from,
Expand All @@ -163,6 +199,7 @@ impl From<protobuf::ConnectionRequest> for ConnectionRequest {
request_timeout,
connection_retry_strategy,
periodic_checks,
pubsub_subscriptions,
}
}
}
Loading

0 comments on commit 18e6a9b

Please sign in to comment.