Skip to content

Commit

Permalink
fix: race conditions, Byzantine tolerance, max message size, timestamp validation
Browse files Browse the repository at this point in the history
  • Loading branch information
royvardhan committed Nov 12, 2024
1 parent 268db34 commit 55c36fc
Showing 1 changed file with 57 additions and 24 deletions.
81 changes: 57 additions & 24 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ use thiserror::Error;
use tracing::{error, info, warn};
use uuid::Uuid;

const MAX_MESSAGE_SIZE: usize = 1024 * 1024; // 1MB limit

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
enum MessageType {
Proposal,
Expand Down Expand Up @@ -63,6 +65,7 @@ struct Node {
tx: mpsc::Sender<Message>,
proposal_acknowledgement: Arc<Mutex<HashMap<String, HashSet<u64>>>>,
config: NodeConfig,
last_commit_timestamp: u64,
}

#[derive(Error, Debug)]
Expand All @@ -75,6 +78,8 @@ pub enum NodeError {
ConsensusTimeout,
#[error("Message send failed")]
MessageSendFailed,
#[error("Message too large")]
MessageTooLarge,
}

#[derive(Clone)]
Expand Down Expand Up @@ -155,8 +160,11 @@ impl Node {
}

async fn try_send_message(&self, message: &Message, receiver: &str) -> Result<(), NodeError> {
let mut stream = TcpStream::connect(receiver).await?;
let ser_message = serde_json::to_vec(message)?;
if ser_message.len() > MAX_MESSAGE_SIZE {
return Err(NodeError::MessageTooLarge);
}
let mut stream = TcpStream::connect(receiver).await?;
stream.write_all(&ser_message).await?;
Ok(())
}
Expand Down Expand Up @@ -194,7 +202,7 @@ impl Node {
self.wait_for_acknowledgements(proposal_id).await
}

async fn handle_incoming_messages(&self, mut rx: mpsc::Receiver<Message>) {
async fn handle_incoming_messages(&mut self, mut rx: mpsc::Receiver<Message>) {
while let Some(message) = rx.recv().await {
counter!("messages_received", 1);
match message.message_type {
Expand Down Expand Up @@ -264,13 +272,16 @@ impl Node {
}
}

async fn handle_commit(&self, message: Message) {
async fn handle_commit(&mut self, message: Message) {
let mut current_state = self.state.lock().await;
if message.timestamp < self.last_commit_timestamp {
warn!("Received outdated commit message");
return;
}
if current_state.can_transition_to(&message.proposed_state) {
*current_state = message.proposed_state;
self.last_commit_timestamp = message.timestamp;
info!("State updated to {:?}", *current_state);
} else {
warn!("Invalid state transition attempted");
}
}

Expand Down Expand Up @@ -300,7 +311,8 @@ impl Node {
}

async fn wait_for_acknowledgements(&self, proposal_id: String) -> Result<(), NodeError> {
let majority = (self.peers.len() / 2) + 1;
let total_nodes = self.peers.len() + 1; // Include self
let majority = (total_nodes * 2 / 3) + 1; // Use 2/3 majority for Byzantine tolerance
let timeout = tokio::time::sleep(self.config.consensus_timeout);
tokio::pin!(timeout);

Expand Down Expand Up @@ -383,38 +395,44 @@ async fn main() -> Result<(), NodeError> {

let (sender1, receiver1) = mpsc::channel(32);

let node1 = Arc::new(Node {
let node1 = Arc::new(Mutex::new(Node {
id: 1,
state: state.clone(),
peers: HashMap::from([(2, "0.0.0.0:8081".to_string())]),
tx: sender1,
address: "0.0.0.0:8080".to_string(),
proposal_acknowledgement: proposal_acknowledgments.clone(),
config: config.clone(),
});
last_commit_timestamp: 0,
}));

let (sender2, receiver2) = mpsc::channel(32);

let node2 = Arc::new(Node {
let node2 = Arc::new(Mutex::new(Node {
id: 2,
state: state.clone(),
peers: HashMap::from([(1, "0.0.0.0:8080".to_string())]),
tx: sender2,
address: "0.0.0.0:8081".to_string(),
proposal_acknowledgement: proposal_acknowledgments,
config: config.clone(),
});
last_commit_timestamp: 0,
}));

let node1_clone_for_messages = Arc::clone(&node1);
spawn(async move {
node1_clone_for_messages
.lock()
.await
.handle_incoming_messages(receiver1)
.await;
});

let node2_clone_for_messages = Arc::clone(&node1);
let node2_clone_for_messages = Arc::clone(&node2);
spawn(async move {
node2_clone_for_messages
.lock()
.await
.handle_incoming_messages(receiver2)
.await;
});
Expand All @@ -423,6 +441,8 @@ async fn main() -> Result<(), NodeError> {
let node1_clone_for_listen = Arc::clone(&node1);
tokio::spawn(async move {
node1_clone_for_listen
.lock()
.await
.listen()
.await
.expect("Node 1 failed to listen");
Expand All @@ -431,6 +451,8 @@ async fn main() -> Result<(), NodeError> {
let node2_clone_for_listen = Arc::clone(&node2);
tokio::spawn(async move {
node2_clone_for_listen
.lock()
.await
.listen()
.await
.expect("Node 2 failed to listen");
Expand All @@ -440,7 +462,11 @@ async fn main() -> Result<(), NodeError> {
tokio::time::sleep(Duration::from_secs(1)).await;

// Use the original `node1` Arc to broadcast a proposal
node1.broadcast_proposal(State::Running).await?;
node1
.lock()
.await
.broadcast_proposal(State::Running)
.await?;

// Start the simulation after a short delay to ensure nodes are listening
tokio::time::sleep(Duration::from_secs(2)).await;
Expand All @@ -461,20 +487,21 @@ mod tests {
use tokio::sync::mpsc;

// Helper function to create a test node
async fn create_test_node(id: u64, address: &str) -> Arc<Node> {
async fn create_test_node(id: u64, address: &str) -> Arc<Mutex<Node>> {
let state = Arc::new(Mutex::new(State::Init));
let proposal_acknowledgments = Arc::new(Mutex::new(HashMap::new()));
let (tx, _) = mpsc::channel(32);

Arc::new(Node {
Arc::new(Mutex::new(Node {
id,
state: state.clone(),
peers: HashMap::new(),
address: address.to_string(),
tx,
proposal_acknowledgement: proposal_acknowledgments,
config: NodeConfig::default(),
})
last_commit_timestamp: 0,
}))
}

#[tokio::test]
Expand All @@ -483,14 +510,16 @@ mod tests {

// Test valid transitions
{
let mut state = node.state.lock().await;
let node_lock = node.lock().await;
let mut state = node_lock.state.lock().await;
assert_eq!(*state, State::Init);
assert!(state.can_transition_to(&State::Running));
*state = State::Running;
}

{
let mut state = node.state.lock().await;
let node_lock = node.lock().await;
let mut state = node_lock.state.lock().await;
assert_eq!(*state, State::Running);
assert!(state.can_transition_to(&State::Stopped));
assert!(!state.can_transition_to(&State::Init));
Expand All @@ -503,7 +532,7 @@ mod tests {
let proposal_id = Uuid::new_v4().to_string();

let message = Message {
sender_id: node.id,
sender_id: node.lock().await.id,
message_type: MessageType::Proposal,
proposed_state: State::Running,
proposal_id: proposal_id.clone(),
Expand Down Expand Up @@ -610,16 +639,18 @@ mod tests {

// Initial state should be Init
{
let state = node.state.lock().await;
let node_lock = node.lock().await;
let state = node_lock.state.lock().await;
assert_eq!(*state, State::Init);
}

// Handle commit message
node.handle_commit(commit_message).await;
node.lock().await.handle_commit(commit_message).await;

// State should be updated to Running
{
let state = node.state.lock().await;
let node_lock = node.lock().await;
let state = node_lock.state.lock().await;
assert_eq!(*state, State::Running);
}
}
Expand All @@ -643,16 +674,18 @@ mod tests {

// Initial state should be Init
{
let state = node.state.lock().await;
let node_lock = node.lock().await;
let state = node_lock.state.lock().await;
assert_eq!(*state, State::Init);
}

// Handle commit message
node.handle_commit(commit_message).await;
node.lock().await.handle_commit(commit_message).await;

// State should still be Init
{
let state = node.state.lock().await;
let node_lock = node.lock().await;
let state = node_lock.state.lock().await;
assert_eq!(*state, State::Init);
}
}
Expand Down

0 comments on commit 55c36fc

Please sign in to comment.