From a0272b834a3569f41feeb008610b31e2f284b2c0 Mon Sep 17 00:00:00 2001 From: Champii1 Date: Tue, 2 Jul 2024 19:10:53 +0200 Subject: [PATCH] Add RPC style call for WorkerSocket and limit the payload size to 64MB --- host/src/server/worker.rs | 22 +++------ lib/src/prover.rs | 2 + .../distributed/orchestrator/worker_client.rs | 16 +------ provers/sp1/driver/src/distributed/prover.rs | 18 ++------ .../driver/src/distributed/sp1_specifics.rs | 4 +- .../driver/src/distributed/worker/protocol.rs | 2 + .../driver/src/distributed/worker/socket.rs | 46 +++++++++++++++++-- 7 files changed, 62 insertions(+), 48 deletions(-) diff --git a/host/src/server/worker.rs b/host/src/server/worker.rs index 0555f7001..65ac85a54 100644 --- a/host/src/server/worker.rs +++ b/host/src/server/worker.rs @@ -7,11 +7,11 @@ use tracing::{error, info, warn}; async fn handle_worker_socket(mut socket: WorkerSocket) -> Result<(), ProverError> { let protocol = socket.receive().await?; - info!("Received request: {}", protocol); + info!("Received request from orchestrator: {}", protocol); match protocol { WorkerProtocol::Ping => { - socket.send(WorkerProtocol::Ping).await?; + socket.send(WorkerProtocol::Pong).await?; } WorkerProtocol::PartialProofRequest(data) => { process_partial_proof_request(socket, data).await?; @@ -26,18 +26,13 @@ async fn process_partial_proof_request( mut socket: WorkerSocket, data: PartialProofRequest, ) -> Result<(), ProverError> { - let result = sp1_driver::Sp1DistributedProver::run_as_worker(data).await; + let partial_proof = sp1_driver::Sp1DistributedProver::run_as_worker(data).await?; - match result { - Ok(data) => Ok(socket - .send(WorkerProtocol::PartialProofResponse(data)) - .await?), - Err(e) => { - error!("Error while processing worker request: {:?}", e); + socket + .send(WorkerProtocol::PartialProofResponse(partial_proof)) + .await?; - Err(e) - } - } + Ok(()) } async fn listen_worker(state: ProverState) { @@ -63,11 +58,8 @@ async fn listen_worker(state: ProverState) { } } - info!("Receiving connection from orchestrator: {}", addr); - // We purposely don't spawn the task here, as we want to block to limit the number // of concurrent connections to one. - if let Err(e) = handle_worker_socket(WorkerSocket::new(socket)).await { error!("Error while handling worker socket: {:?}", e); } diff --git a/lib/src/prover.rs b/lib/src/prover.rs index c2efe6b48..4e26c30ca 100644 --- a/lib/src/prover.rs +++ b/lib/src/prover.rs @@ -54,4 +54,6 @@ pub enum WorkerError { InvalidRequest, #[error("Worker invalid response")] InvalidResponse, + #[error("Worker payload too big")] + PayloadTooBig, } diff --git a/provers/sp1/driver/src/distributed/orchestrator/worker_client.rs b/provers/sp1/driver/src/distributed/orchestrator/worker_client.rs index 38bea852f..73203b7fa 100644 --- a/provers/sp1/driver/src/distributed/orchestrator/worker_client.rs +++ b/provers/sp1/driver/src/distributed/orchestrator/worker_client.rs @@ -2,9 +2,7 @@ use async_channel::{Receiver, Sender}; use raiko_lib::prover::WorkerError; use sp1_core::{runtime::ExecutionState, stark::ShardProof, utils::BabyBearPoseidon2}; -use crate::{ - distributed::partial_proof_request::PartialProofRequest, WorkerProtocol, WorkerSocket, -}; +use crate::{distributed::partial_proof_request::PartialProofRequest, WorkerSocket}; pub struct WorkerClient { /// The id of the worker @@ -84,16 +82,6 @@ impl WorkerClient { request.checkpoint_id = i; request.checkpoint_data = checkpoint; - socket - .send(WorkerProtocol::PartialProofRequest(request)) - .await?; - - let response = socket.receive().await?; - - if let WorkerProtocol::PartialProofResponse(partial_proofs) = response { - Ok(partial_proofs) - } else { - Err(WorkerError::InvalidResponse) - } + socket.partial_proof_request(request).await } } diff --git a/provers/sp1/driver/src/distributed/prover.rs b/provers/sp1/driver/src/distributed/prover.rs index 9576bf4c4..7013cbccb 100644 --- a/provers/sp1/driver/src/distributed/prover.rs +++ b/provers/sp1/driver/src/distributed/prover.rs @@ -12,7 +12,7 @@ use crate::{ partial_proof_request::PartialProofRequest, sp1_specifics::{commit, prove_partial}, }, - Sp1Response, WorkerProtocol, WorkerSocket, ELF, + Sp1Response, WorkerSocket, ELF, }; pub struct Sp1DistributedProver; @@ -123,23 +123,13 @@ impl Sp1DistributedProver { continue; }; - if let Err(_) = socket.send(WorkerProtocol::Ping).await { - log::warn!("Sp1 Distributed: Worker at {} is not reachable. Removing from the list for this task", ip); + if let Err(_) = socket.ping().await { + log::warn!("Sp1 Distributed: Worker at {} is not sending good response to Ping. Removing from the list for this task", ip); continue; } - let Ok(response) = socket.receive().await else { - log::warn!("Sp1 Distributed: Worker at {} is not a valid SP1 worker. Removing from the list for this task", ip); - - continue; - }; - - if let WorkerProtocol::Ping = response { - reachable_ip_list.push(ip.clone()); - } else { - log::warn!("Sp1 Distributed: Worker at {} is not a valid SP1 worker. Removing from the list for this task", ip); - } + reachable_ip_list.push(ip.clone()); } if reachable_ip_list.is_empty() { diff --git a/provers/sp1/driver/src/distributed/sp1_specifics.rs b/provers/sp1/driver/src/distributed/sp1_specifics.rs index 12eba3903..6b6d5f558 100644 --- a/provers/sp1/driver/src/distributed/sp1_specifics.rs +++ b/provers/sp1/driver/src/distributed/sp1_specifics.rs @@ -238,11 +238,13 @@ pub fn prove_partial(request_data: &PartialProofRequest) -> Vec>), } @@ -16,6 +17,7 @@ impl Display for WorkerProtocol { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { WorkerProtocol::Ping => write!(f, "Ping"), + WorkerProtocol::Pong => write!(f, "Pong"), WorkerProtocol::PartialProofRequest(_) => write!(f, "PartialProofRequest"), WorkerProtocol::PartialProofResponse(_) => write!(f, "PartialProofResponse"), } diff --git a/provers/sp1/driver/src/distributed/worker/socket.rs b/provers/sp1/driver/src/distributed/worker/socket.rs index f744ba0b7..da64e3a6f 100644 --- a/provers/sp1/driver/src/distributed/worker/socket.rs +++ b/provers/sp1/driver/src/distributed/worker/socket.rs @@ -1,7 +1,11 @@ use raiko_lib::prover::WorkerError; +use sp1_core::{stark::ShardProof, utils::BabyBearPoseidon2}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; -use crate::{WorkerEnvelope, WorkerProtocol}; +use crate::{PartialProofRequest, WorkerEnvelope, WorkerProtocol}; + +// 64MB +const PAYLOAD_MAX_SIZE: usize = 1 << 26; pub struct WorkerSocket { pub socket: tokio::net::TcpStream, @@ -23,6 +27,10 @@ impl WorkerSocket { let data = bincode::serialize(&envelope)?; + if data.len() > PAYLOAD_MAX_SIZE { + return Err(WorkerError::PayloadTooBig); + } + self.socket.write_u64(data.len() as u64).await?; self.socket.write_all(&data).await?; @@ -42,10 +50,13 @@ impl WorkerSocket { } // TODO: Add a timeout - pub async fn read_data(&mut self) -> Result, std::io::Error> { - // TODO: limit the size of the data + pub async fn read_data(&mut self) -> Result, WorkerError> { let size = self.socket.read_u64().await? as usize; + if size > PAYLOAD_MAX_SIZE { + return Err(WorkerError::PayloadTooBig); + } + let mut data = Vec::new(); let mut buf_data = BufWriter::new(&mut data); @@ -72,9 +83,36 @@ impl WorkerSocket { Err(e) => { log::error!("failed to read from socket; err = {:?}", e); - return Err(e); + return Err(e.into()); } }; } } + + pub async fn ping(&mut self) -> Result<(), WorkerError> { + self.send(WorkerProtocol::Ping).await?; + + let response = self.receive().await?; + + match response { + WorkerProtocol::Pong => Ok(()), + _ => Err(WorkerError::InvalidResponse), + } + } + + pub async fn partial_proof_request( + &mut self, + request: PartialProofRequest, + ) -> Result>, WorkerError> { + self.send(WorkerProtocol::PartialProofRequest(request)) + .await?; + + let response = self.receive().await?; + + if let WorkerProtocol::PartialProofResponse(partial_proofs) = response { + Ok(partial_proofs) + } else { + Err(WorkerError::InvalidResponse) + } + } }