Skip to content

Commit

Permalink
First working version of SP1 Distributed Prover
Browse files Browse the repository at this point in the history
  • Loading branch information
Champii committed Jul 3, 2024
1 parent 48ea079 commit c52bc22
Show file tree
Hide file tree
Showing 25 changed files with 1,234 additions and 136 deletions.
233 changes: 168 additions & 65 deletions Cargo.lock

Large diffs are not rendered by default.

26 changes: 23 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,19 @@ risc0-build = { version = "0.21.0" }
risc0-binfmt = { version = "0.21.0" }

# SP1
sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
sp1-helper = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
sp1-helper = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
sp1-core = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }


# Plonky3
p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }


# alloy
alloy-rlp = { version = "0.3.4", default-features = false }
Expand Down Expand Up @@ -149,6 +159,7 @@ secp256k1 = { version = "0.29", default-features = false, features = [
"global-context",
"recovery",
] }
async-channel = "2.3.1"

# macro
syn = { version = "1.0", features = ["full"] }
Expand Down Expand Up @@ -188,3 +199,12 @@ revm-primitives = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-
revm-precompile = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-taiko" }
secp256k1 = { git = "https://github.com/CeciliaZ030/rust-secp256k1", branch = "sp1-patch" }
blst = { git = "https://github.com/CeciliaZ030/blst.git", branch = "v0.3.12-serialize" }

# Patch Plonky3 for Serialize and Deserialize of DuplexChallenger
[patch."https://github.com/Plonky3/Plonky3.git"]
p3-field = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
p3-challenger = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
p3-poseidon2 = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
p3-baby-bear = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
p3-symmetric = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }

14 changes: 14 additions & 0 deletions core/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ pub enum ProofType {
///
/// Uses the SP1 prover to build the block.
Sp1,
/// # Sp1Distributed
///
/// Uses the SP1 prover to build the block in a distributed way.
Sp1Distributed,
/// # Sgx
///
/// Builds the block on a SGX supported CPU to create a proof.
Expand All @@ -119,6 +123,7 @@ impl std::fmt::Display for ProofType {
f.write_str(match self {
ProofType::Native => "native",
ProofType::Sp1 => "sp1",
ProofType::Sp1Distributed => "sp1_distributed",
ProofType::Sgx => "sgx",
ProofType::Risc0 => "risc0",
})
Expand All @@ -132,6 +137,7 @@ impl FromStr for ProofType {
match s.trim().to_lowercase().as_str() {
"native" => Ok(ProofType::Native),
"sp1" => Ok(ProofType::Sp1),
"sp1_distributed" => Ok(ProofType::Sp1Distributed),
"sgx" => Ok(ProofType::Sgx),
"risc0" => Ok(ProofType::Risc0),
_ => Err(RaikoError::InvalidProofType(s.to_string())),
Expand Down Expand Up @@ -159,6 +165,14 @@ impl ProofType {
#[cfg(not(feature = "sp1"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
ProofType::Sp1Distributed => {
#[cfg(feature = "sp1")]
return sp1_driver::Sp1DistributedProver::run(input, output, config)
.await
.map_err(|e| e.into());
#[cfg(not(feature = "sp1"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
ProofType::Risc0 => {
#[cfg(feature = "risc0")]
return risc0_driver::Risc0Prover::run(input.clone(), output, config)
Expand Down
3 changes: 2 additions & 1 deletion core/src/preflight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ async fn prepare_taiko_chain_input(
.await?;

// Fetch the tx data from either calldata or blobdata
let (tx_data, blob_commitment) = if proposal_event.meta.blobUsed {
// let (tx_data, blob_commitment) = if proposal_event.meta.blobUsed {
let (tx_data, blob_commitment) = if false {
debug!("blob active");
// Get the blob hashes attached to the propose tx
let blob_hashes = proposal_tx.blob_versioned_hashes.unwrap_or_default();
Expand Down
2 changes: 1 addition & 1 deletion host/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ethers-core = { workspace = true }

[features]
default = []
sp1 = ["raiko-core/sp1"]
sp1 = ["raiko-core/sp1", "sp1-driver"]
risc0 = ["raiko-core/risc0"]
sgx = ["raiko-core/sgx"]

Expand Down
15 changes: 15 additions & 0 deletions host/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ fn default_address() -> String {
"0.0.0.0:8080".to_string()
}

fn default_worker_address() -> String {
"0.0.0.0:8081".to_string()
}

fn default_concurrency_limit() -> usize {
16
}
Expand Down Expand Up @@ -69,6 +73,17 @@ pub struct Cli {
/// [default: 0.0.0.0:8080]
address: String,

#[arg(long, require_equals = true, default_value = "0.0.0.0:8081")]
#[serde(default = "default_worker_address")]
/// Distributed SP1 worker listening address
/// [default: 0.0.0.0:8081]
worker_address: String,

#[arg(long, default_value = None)]
/// Distributed SP1 worker orchestrator address
/// [default: None]
orchestrator_address: Option<String>,

#[arg(long, require_equals = true, default_value = "16")]
#[serde(default = "default_concurrency_limit")]
/// Limit the max number of in-flight requests
Expand Down
5 changes: 5 additions & 0 deletions host/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ use tokio::net::TcpListener;
use tracing::info;

pub mod api;
#[cfg(feature = "sp1")]
pub mod worker;

/// Starts the proverd server.
pub async fn serve(state: ProverState) -> anyhow::Result<()> {
#[cfg(feature = "sp1")]
worker::serve(state.clone()).await;

let addr = SocketAddr::from_str(&state.opts.address)
.map_err(|_| HostError::InvalidAddress(state.opts.address.clone()))?;
let listener = TcpListener::bind(addr).await?;
Expand Down
73 changes: 73 additions & 0 deletions host/src/server/worker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use crate::ProverState;
use raiko_lib::prover::{ProverError, WorkerError};
use sp1_driver::{PartialProofRequest, WorkerProtocol, WorkerSocket};
use tokio::net::TcpListener;
use tracing::{error, info, warn};

async fn handle_worker_socket(mut socket: WorkerSocket) -> Result<(), ProverError> {
let protocol = socket.receive().await?;

info!("Received request from orchestrator: {}", protocol);

match protocol {
WorkerProtocol::Ping => {
socket.send(WorkerProtocol::Pong).await?;
}
WorkerProtocol::PartialProofRequest(data) => {
process_partial_proof_request(socket, data).await?;
}
_ => Err(WorkerError::InvalidRequest)?,
}

Ok(())
}

async fn process_partial_proof_request(
mut socket: WorkerSocket,
data: PartialProofRequest,
) -> Result<(), ProverError> {
let partial_proof = sp1_driver::Sp1DistributedProver::run_as_worker(data).await?;

socket
.send(WorkerProtocol::PartialProofResponse(partial_proof))
.await?;

Ok(())
}

async fn listen_worker(state: ProverState) {
info!(
"Listening as a SP1 worker on: {}",
state.opts.worker_address
);

let listener = TcpListener::bind(state.opts.worker_address).await.unwrap();

loop {
let Ok((socket, addr)) = listener.accept().await else {
error!("Error while accepting connection from orchestrator: Closing socket");

return;
};

if let Some(orchestrator_address) = &state.opts.orchestrator_address {
if addr.ip().to_string() != *orchestrator_address {
warn!("Unauthorized orchestrator connection from: {}", addr);

continue;
}
}

// We purposely don't spawn the task here, as we want to block to limit the number
// of concurrent connections to one.
if let Err(e) = handle_worker_socket(WorkerSocket::new(socket)).await {
error!("Error while handling worker socket: {:?}", e);
}
}
}

pub async fn serve(state: ProverState) {
if state.opts.orchestrator_address.is_some() {
tokio::spawn(listen_worker(state));
}
}
2 changes: 1 addition & 1 deletion lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ std = [
sgx = []
sp1 = []
risc0 = []
sp1-cycle-tracker = []
sp1-cycle-tracker = []
20 changes: 20 additions & 0 deletions lib/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub enum ProverError {
FileIo(#[from] std::io::Error),
#[error("ProverError::Param `{0}`")]
Param(#[from] serde_json::Error),
#[error("ProverError::Worker `{0}`")]
Worker(#[from] WorkerError),
}

impl From<String> for ProverError {
Expand All @@ -37,3 +39,21 @@ pub fn to_proof(proof: ProverResult<impl Serialize>) -> ProverResult<Proof> {
serde_json::to_value(res).map_err(|err| ProverError::GuestError(err.to_string()))
})
}

#[derive(ThisError, Debug)]
pub enum WorkerError {
#[error("All workers failed")]
AllWorkersFailed,
#[error("Worker IO error: {0}")]
IO(#[from] std::io::Error),
#[error("Worker Serde error: {0}")]
Serde(#[from] bincode::Error),
#[error("Worker invalid magic number")]
InvalidMagicNumber,
#[error("Worker invalid request")]
InvalidRequest,
#[error("Worker invalid response")]
InvalidResponse,
#[error("Worker payload too big")]
PayloadTooBig,
}
15 changes: 15 additions & 0 deletions provers/sp1/driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,32 @@ alloy-sol-types = { workspace = true }
serde = { workspace = true , optional = true}
serde_json = { workspace = true , optional = true }
sp1-sdk = { workspace = true, optional = true }
sp1-core = { workspace = true, optional = true }
anyhow = { workspace = true, optional = true }
once_cell = { workspace = true, optional = true }
sha3 = { workspace = true, optional = true, default-features = false}

log = { workspace = true }
tokio = { workspace = true }
async-channel = { workspace = true }
tracing = { workspace = true }
tempfile = { workspace = true }
bincode = { workspace = true }

p3-field = { workspace = true }
p3-challenger = { workspace = true }
p3-poseidon2 = { workspace = true }
p3-baby-bear = { workspace = true }
p3-symmetric = { workspace = true }


[features]
enable = [
"serde",
"serde_json",
"raiko-lib",
"sp1-sdk",
"sp1-core",
"anyhow",
"alloy-primitives",
"once_cell",
Expand Down
9 changes: 9 additions & 0 deletions provers/sp1/driver/src/distributed/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
mod orchestrator;
mod partial_proof_request;
mod prover;
mod sp1_specifics;
mod worker;

pub use partial_proof_request::PartialProofRequest;
pub use prover::Sp1DistributedProver;
pub use worker::{WorkerEnvelope, WorkerProtocol, WorkerSocket};
79 changes: 79 additions & 0 deletions provers/sp1/driver/src/distributed/orchestrator/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
mod worker_client;

use raiko_lib::prover::WorkerError;
use sp1_core::{runtime::ExecutionState, stark::ShardProof, utils::BabyBearPoseidon2};
use worker_client::WorkerClient;

use super::partial_proof_request::PartialProofRequest;

pub async fn distribute_work(
ip_list: Vec<String>,
checkpoints: Vec<ExecutionState>,
partial_proof_request: PartialProofRequest,
) -> Result<Vec<ShardProof<BabyBearPoseidon2>>, WorkerError> {
let mut nb_workers = ip_list.len();

let (queue_tx, queue_rx) = async_channel::bounded(nb_workers);
let (answer_tx, answer_rx) = async_channel::bounded(nb_workers);

// Spawn the workers
for (i, url) in ip_list.iter().enumerate() {
let worker = WorkerClient::new(
i,
url.clone(),
queue_rx.clone(),
answer_tx.clone(),
partial_proof_request.clone(),
);

tokio::spawn(async move {
worker.run().await;
});
}

// Send the checkpoints to the workers
for (i, checkpoint) in checkpoints.iter().enumerate() {
queue_tx.send((i, checkpoint.clone())).await.unwrap();
}

let mut proofs = Vec::new();

// Get the partial proofs from the workers
loop {
let (checkpoint_id, partial_proof_result) = answer_rx.recv().await.unwrap();

match partial_proof_result {
Ok(partial_proof) => {
proofs.push((checkpoint_id as usize, partial_proof));
}
Err(_e) => {
// Decrease the number of workers
nb_workers -= 1;

if nb_workers == 0 {
return Err(WorkerError::AllWorkersFailed);
}

// Push back the work for it to be done by another worker
queue_tx
.send((checkpoint_id, checkpoints[checkpoint_id as usize].clone()))
.await
.unwrap();
}
}

if proofs.len() == checkpoints.len() {
break;
}
}

proofs.sort_by_key(|(checkpoint_id, _)| *checkpoint_id);

let proofs = proofs
.into_iter()
.map(|(_, proof)| proof)
.flatten()
.collect();

Ok(proofs)
}
Loading

0 comments on commit c52bc22

Please sign in to comment.