Skip to content

Commit

Permalink
Add RPC style call for WorkerSocket and limit the payload size to 64MB
Browse files Browse the repository at this point in the history
  • Loading branch information
Champii committed Jul 2, 2024
1 parent dfd8366 commit a0272b8
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 48 deletions.
22 changes: 7 additions & 15 deletions host/src/server/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ use tracing::{error, info, warn};
async fn handle_worker_socket(mut socket: WorkerSocket) -> Result<(), ProverError> {
let protocol = socket.receive().await?;

info!("Received request: {}", protocol);
info!("Received request from orchestrator: {}", protocol);

match protocol {
WorkerProtocol::Ping => {
socket.send(WorkerProtocol::Ping).await?;
socket.send(WorkerProtocol::Pong).await?;
}
WorkerProtocol::PartialProofRequest(data) => {
process_partial_proof_request(socket, data).await?;
Expand All @@ -26,18 +26,13 @@ async fn process_partial_proof_request(
mut socket: WorkerSocket,
data: PartialProofRequest,
) -> Result<(), ProverError> {
let result = sp1_driver::Sp1DistributedProver::run_as_worker(data).await;
let partial_proof = sp1_driver::Sp1DistributedProver::run_as_worker(data).await?;

match result {
Ok(data) => Ok(socket
.send(WorkerProtocol::PartialProofResponse(data))
.await?),
Err(e) => {
error!("Error while processing worker request: {:?}", e);
socket
.send(WorkerProtocol::PartialProofResponse(partial_proof))
.await?;

Err(e)
}
}
Ok(())
}

async fn listen_worker(state: ProverState) {
Expand All @@ -63,11 +58,8 @@ async fn listen_worker(state: ProverState) {
}
}

info!("Receiving connection from orchestrator: {}", addr);

// We purposely don't spawn the task here, as we want to block to limit the number
// of concurrent connections to one.

if let Err(e) = handle_worker_socket(WorkerSocket::new(socket)).await {
error!("Error while handling worker socket: {:?}", e);
}
Expand Down
2 changes: 2 additions & 0 deletions lib/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,6 @@ pub enum WorkerError {
InvalidRequest,
#[error("Worker invalid response")]
InvalidResponse,
#[error("Worker payload too big")]
PayloadTooBig,
}
16 changes: 2 additions & 14 deletions provers/sp1/driver/src/distributed/orchestrator/worker_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ use async_channel::{Receiver, Sender};
use raiko_lib::prover::WorkerError;
use sp1_core::{runtime::ExecutionState, stark::ShardProof, utils::BabyBearPoseidon2};

use crate::{
distributed::partial_proof_request::PartialProofRequest, WorkerProtocol, WorkerSocket,
};
use crate::{distributed::partial_proof_request::PartialProofRequest, WorkerSocket};

pub struct WorkerClient {
/// The id of the worker
Expand Down Expand Up @@ -84,16 +82,6 @@ impl WorkerClient {
request.checkpoint_id = i;
request.checkpoint_data = checkpoint;

socket
.send(WorkerProtocol::PartialProofRequest(request))
.await?;

let response = socket.receive().await?;

if let WorkerProtocol::PartialProofResponse(partial_proofs) = response {
Ok(partial_proofs)
} else {
Err(WorkerError::InvalidResponse)
}
socket.partial_proof_request(request).await
}
}
18 changes: 4 additions & 14 deletions provers/sp1/driver/src/distributed/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
partial_proof_request::PartialProofRequest,
sp1_specifics::{commit, prove_partial},
},
Sp1Response, WorkerProtocol, WorkerSocket, ELF,
Sp1Response, WorkerSocket, ELF,
};

pub struct Sp1DistributedProver;
Expand Down Expand Up @@ -123,23 +123,13 @@ impl Sp1DistributedProver {
continue;
};

if let Err(_) = socket.send(WorkerProtocol::Ping).await {
log::warn!("Sp1 Distributed: Worker at {} is not reachable. Removing from the list for this task", ip);
if let Err(_) = socket.ping().await {
log::warn!("Sp1 Distributed: Worker at {} is not sending good response to Ping. Removing from the list for this task", ip);

continue;
}

let Ok(response) = socket.receive().await else {
log::warn!("Sp1 Distributed: Worker at {} is not a valid SP1 worker. Removing from the list for this task", ip);

continue;
};

if let WorkerProtocol::Ping = response {
reachable_ip_list.push(ip.clone());
} else {
log::warn!("Sp1 Distributed: Worker at {} is not a valid SP1 worker. Removing from the list for this task", ip);
}
reachable_ip_list.push(ip.clone());
}

if reachable_ip_list.is_empty() {
Expand Down
4 changes: 3 additions & 1 deletion provers/sp1/driver/src/distributed/sp1_specifics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,13 @@ pub fn prove_partial(request_data: &PartialProofRequest) -> Vec<ShardProof<BabyB

log::debug!("Checkpoint sharding took {:?}", now.elapsed());

let nb_shards = checkpoint_shards.len();

let mut proofs = checkpoint_shards
.into_iter()
.enumerate()
.map(|(i, shard)| {
log::info!("Proving shard {}/{}", i + 1, request_data.shard_batch_size);
log::info!("Proving shard {}/{}", i + 1, nb_shards);

let config = machine.config();

Expand Down
2 changes: 2 additions & 0 deletions provers/sp1/driver/src/distributed/worker/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::PartialProofRequest;
#[derive(Debug, Serialize, Deserialize)]
pub enum WorkerProtocol {
Ping,
Pong,
PartialProofRequest(PartialProofRequest),
PartialProofResponse(Vec<ShardProof<BabyBearPoseidon2>>),
}
Expand All @@ -16,6 +17,7 @@ impl Display for WorkerProtocol {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
WorkerProtocol::Ping => write!(f, "Ping"),
WorkerProtocol::Pong => write!(f, "Pong"),
WorkerProtocol::PartialProofRequest(_) => write!(f, "PartialProofRequest"),
WorkerProtocol::PartialProofResponse(_) => write!(f, "PartialProofResponse"),
}
Expand Down
46 changes: 42 additions & 4 deletions provers/sp1/driver/src/distributed/worker/socket.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use raiko_lib::prover::WorkerError;
use sp1_core::{stark::ShardProof, utils::BabyBearPoseidon2};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};

use crate::{WorkerEnvelope, WorkerProtocol};
use crate::{PartialProofRequest, WorkerEnvelope, WorkerProtocol};

// 64MB
const PAYLOAD_MAX_SIZE: usize = 1 << 26;

pub struct WorkerSocket {
pub socket: tokio::net::TcpStream,
Expand All @@ -23,6 +27,10 @@ impl WorkerSocket {

let data = bincode::serialize(&envelope)?;

if data.len() > PAYLOAD_MAX_SIZE {
return Err(WorkerError::PayloadTooBig);
}

self.socket.write_u64(data.len() as u64).await?;
self.socket.write_all(&data).await?;

Expand All @@ -42,10 +50,13 @@ impl WorkerSocket {
}

// TODO: Add a timeout
pub async fn read_data(&mut self) -> Result<Vec<u8>, std::io::Error> {
// TODO: limit the size of the data
pub async fn read_data(&mut self) -> Result<Vec<u8>, WorkerError> {
let size = self.socket.read_u64().await? as usize;

if size > PAYLOAD_MAX_SIZE {
return Err(WorkerError::PayloadTooBig);
}

let mut data = Vec::new();

let mut buf_data = BufWriter::new(&mut data);
Expand All @@ -72,9 +83,36 @@ impl WorkerSocket {
Err(e) => {
log::error!("failed to read from socket; err = {:?}", e);

return Err(e);
return Err(e.into());
}
};
}
}

pub async fn ping(&mut self) -> Result<(), WorkerError> {
self.send(WorkerProtocol::Ping).await?;

let response = self.receive().await?;

match response {
WorkerProtocol::Pong => Ok(()),
_ => Err(WorkerError::InvalidResponse),
}
}

pub async fn partial_proof_request(
&mut self,
request: PartialProofRequest,
) -> Result<Vec<ShardProof<BabyBearPoseidon2>>, WorkerError> {
self.send(WorkerProtocol::PartialProofRequest(request))
.await?;

let response = self.receive().await?;

if let WorkerProtocol::PartialProofResponse(partial_proofs) = response {
Ok(partial_proofs)
} else {
Err(WorkerError::InvalidResponse)
}
}
}

0 comments on commit a0272b8

Please sign in to comment.