Skip to content

Commit

Permalink
feat(raiko): merge stress test upgrades (#392)
Browse files Browse the repository at this point in the history
* update risc0-zkvm  to v1.1.2

* use async sleep to avoid deadlock

* update docker build

* verify zk proof by default

* update to sp1-sdk 3.0.0

* fix: print error in run_task

* update patches (#388)

* fix typo

* fix stress test merge issues

* fix sp1 verification fixture format

* fix risc0 verification

* update env script

---------

Co-authored-by: john xu <john@taiko.xyz>
Co-authored-by: Chris T. <chris@succinct.xyz>
  • Loading branch information
3 people authored Oct 21, 2024
1 parent 8471b16 commit 7f64cbe
Show file tree
Hide file tree
Showing 19 changed files with 631 additions and 502 deletions.
640 changes: 306 additions & 334 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,12 @@ reth-chainspec = { git = "https://github.com/taikoxyz/taiko-reth.git", branch =
reth-provider = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false }

# risc zero
risc0-zkvm = { version = "1.0.1", features = ["prove", "getrandom"] }
bonsai-sdk = { version = "0.8.0", features = ["async"] }
risc0-build = { version = "1.0.1" }
risc0-binfmt = { version = "1.0.1" }
risc0-zkvm = { version = "=1.1.2", features = ["prove", "getrandom"] }
bonsai-sdk = { version = "=1.1.2" }
risc0-binfmt = { version = "=1.1.2" }

# SP1
sp1-sdk = { version = "2.0.0" }
sp1-sdk = { version = "=3.0.0-rc3" }
sp1-zkvm = { version = "2.0.0" }
sp1-helper = { version = "2.0.0" }

Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.zk
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ RUN echo "Building for sp1"
ENV TARGET=sp1
RUN make install
RUN make guest
RUN cargo build --release ${BUILD_FLAGS} --features "sp1,risc0,bonsai-auto-scaling" --features "docker_build"
RUN cargo build --release ${BUILD_FLAGS} --features "sp1,risc0" --features "docker_build"

RUN mkdir -p \
./bin \
Expand Down
7 changes: 5 additions & 2 deletions docker/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,11 @@ services:
volumes:
- /var/log/raiko:/var/log/raiko
ports:
- "8081:8080"
- "8080:8080"
environment:
# you can use your own PCCS host
# - PCCS_HOST=host.docker.internal:8081
- RUST_LOG=${RUST_LOG:-info}
- ZK=true
- ETHEREUM_RPC=${ETHEREUM_RPC}
- ETHEREUM_BEACON_RPC=${ETHEREUM_BEACON_RPC}
Expand All @@ -145,11 +146,13 @@ services:
- NETWORK=${NETWORK}
- BONSAI_API_KEY=${BONSAI_API_KEY}
- BONSAI_API_URL=${BONSAI_API_URL}
- MAX_BONSAI_GPU_NUM=15
- MAX_BONSAI_GPU_NUM=300
- GROTH16_VERIFIER_RPC_URL=${GROTH16_VERIFIER_RPC_URL}
- GROTH16_VERIFIER_ADDRESS=${GROTH16_VERIFIER_ADDRESS}
- SP1_PRIVATE_KEY=${SP1_PRIVATE_KEY}
- SKIP_SIMULATION=true
- SP1_VERIFIER_RPC_URL=${SP1_VERIFIER_RPC_URL}
- SP1_VERIFIER_ADDRESS=${SP1_VERIFIER_ADDRESS}
pccs:
build:
context: ..
Expand Down
11 changes: 7 additions & 4 deletions host/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,18 @@ impl ProofActor {
pub async fn run_task(&mut self, proof_request: ProofRequest) {
let cancel_token = CancellationToken::new();

let Ok((chain_id, blockhash)) = get_task_data(
let (chain_id, blockhash) = match get_task_data(
&proof_request.network,
proof_request.block_number,
&self.chain_specs,
)
.await
else {
error!("Could not get task data for {proof_request:?}");
return;
{
Ok(v) => v,
Err(e) => {
error!("Could not get task data for {proof_request:?}, error: {e}");
return;
}
};

let key = TaskDescriptor::from((
Expand Down
60 changes: 39 additions & 21 deletions provers/risc0/driver/src/bonsai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
Risc0Response,
};
use alloy_primitives::B256;
use bonsai_sdk::blocking::{Client, SessionId};
use log::{debug, error, info, warn};
use raiko_lib::{
primitives::keccak::keccak,
Expand All @@ -19,14 +20,17 @@ use std::{
fs,
path::{Path, PathBuf},
};
use tokio::time::{sleep as tokio_async_sleep, Duration};

use crate::Risc0Param;

const MAX_REQUEST_RETRY: usize = 8;

#[derive(thiserror::Error, Debug)]
pub enum BonsaiExecutionError {
// common errors: include sdk error, or some others from non-bonsai code
#[error(transparent)]
SdkFailure(#[from] bonsai_sdk::alpha::SdkErr),
SdkFailure(#[from] bonsai_sdk::SdkErr),
#[error("bonsai execution error: {0}")]
Other(String),
// critical error like OOM, which is un-recoverable
Expand All @@ -44,12 +48,12 @@ pub async fn verify_bonsai_receipt<O: Eq + Debug + DeserializeOwned>(
max_retries: usize,
) -> Result<(String, Receipt), BonsaiExecutionError> {
info!("Tracking receipt uuid: {uuid}");
let session = bonsai_sdk::alpha::SessionId { uuid };
let session = SessionId { uuid };

loop {
let mut res = None;
for attempt in 1..=max_retries {
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let client = Client::from_env(risc0_zkvm::VERSION)?;

match session.status(&client) {
Ok(response) => {
Expand All @@ -61,7 +65,7 @@ pub async fn verify_bonsai_receipt<O: Eq + Debug + DeserializeOwned>(
return Err(BonsaiExecutionError::SdkFailure(err));
}
warn!("Attempt {attempt}/{max_retries} for session status request: {err:?}");
std::thread::sleep(std::time::Duration::from_secs(15));
tokio_async_sleep(Duration::from_secs(15)).await;
continue;
}
}
Expand All @@ -72,17 +76,18 @@ pub async fn verify_bonsai_receipt<O: Eq + Debug + DeserializeOwned>(

if res.status == "RUNNING" {
info!(
"Current status: {} - state: {} - continue polling...",
"Current {session:?} status: {} - state: {} - continue polling...",
res.status,
res.state.unwrap_or_default()
);
std::thread::sleep(std::time::Duration::from_secs(15));
tokio_async_sleep(Duration::from_secs(15)).await;
} else if res.status == "SUCCEEDED" {
// Download the receipt, containing the output
info!("Prove task {session:?} success.");
let receipt_url = res
.receipt_url
.expect("API error, missing receipt on completed session");
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let client = Client::from_env(risc0_zkvm::VERSION)?;
let receipt_buf = client.download(&receipt_url)?;
let receipt: Receipt = bincode::deserialize(&receipt_buf).map_err(|e| {
BonsaiExecutionError::Other(format!("Failed to deserialize receipt: {e:?}"))
Expand All @@ -104,10 +109,10 @@ pub async fn verify_bonsai_receipt<O: Eq + Debug + DeserializeOwned>(
}
return Ok((session.uuid, receipt));
} else {
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let client = Client::from_env(risc0_zkvm::VERSION)?;
let bonsai_err_log = session.logs(&client);
return Err(BonsaiExecutionError::Fatal(format!(
"Workflow exited: {} - | err: {} | log: {bonsai_err_log:?}",
"Workflow {session:?} exited: {} - | err: {} | log: {bonsai_err_log:?}",
res.status,
res.error_msg.unwrap_or_default(),
)));
Expand Down Expand Up @@ -167,11 +172,11 @@ pub async fn maybe_prove<I: Serialize, O: Eq + Debug + Serialize + DeserializeOw
}
Err(BonsaiExecutionError::SdkFailure(err)) => {
warn!("Bonsai SDK fail: {err:?}, keep tracking...");
std::thread::sleep(std::time::Duration::from_secs(15));
tokio_async_sleep(Duration::from_secs(15)).await;
}
Err(BonsaiExecutionError::Other(err)) => {
warn!("Something wrong: {err:?}, keep tracking...");
std::thread::sleep(std::time::Duration::from_secs(15));
tokio_async_sleep(Duration::from_secs(15)).await;
}
Err(BonsaiExecutionError::Fatal(err)) => {
error!("Fatal error on Bonsai: {err:?}");
Expand Down Expand Up @@ -228,13 +233,13 @@ pub async fn maybe_prove<I: Serialize, O: Eq + Debug + Serialize + DeserializeOw
}

pub async fn upload_receipt(receipt: &Receipt) -> anyhow::Result<String> {
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let client = Client::from_env(risc0_zkvm::VERSION)?;
Ok(client.upload_receipt(bincode::serialize(receipt)?)?)
}

pub async fn cancel_proof(uuid: String) -> anyhow::Result<()> {
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let session = bonsai_sdk::alpha::SessionId { uuid };
let client = Client::from_env(risc0_zkvm::VERSION)?;
let session = SessionId { uuid };
session.stop(&client)?;
#[cfg(feature = "bonsai-auto-scaling")]
auto_scaling::shutdown_bonsai().await?;
Expand All @@ -257,7 +262,7 @@ pub async fn prove_bonsai<O: Eq + Debug + DeserializeOwned>(
// Prepare input data
let input_data = bytemuck::cast_slice(&encoded_input).to_vec();

let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let client = Client::from_env(risc0_zkvm::VERSION)?;
client.upload_img(&encoded_image_id, elf.to_vec())?;
// upload input
let input_id = client.upload_input(input_data.clone())?;
Expand All @@ -266,6 +271,7 @@ pub async fn prove_bonsai<O: Eq + Debug + DeserializeOwned>(
encoded_image_id.clone(),
input_id.clone(),
assumption_uuids.clone(),
false,
)?;

if let Some(id_store) = id_store {
Expand All @@ -277,7 +283,13 @@ pub async fn prove_bonsai<O: Eq + Debug + DeserializeOwned>(
})?;
}

verify_bonsai_receipt(image_id, expected_output, session.uuid.clone(), 8).await
verify_bonsai_receipt(
image_id,
expected_output,
session.uuid.clone(),
MAX_REQUEST_RETRY,
)
.await
}

pub async fn bonsai_stark_to_snark(
Expand All @@ -286,10 +298,14 @@ pub async fn bonsai_stark_to_snark(
input: B256,
) -> ProverResult<Risc0Response> {
let image_id = Digest::from(RISC0_GUEST_ID);
let (snark_uuid, snark_receipt) =
stark2snark(image_id, stark_uuid.clone(), stark_receipt.clone())
.await
.map_err(|err| format!("Failed to convert STARK to SNARK: {err:?}"))?;
let (snark_uuid, snark_receipt) = stark2snark(
image_id,
stark_uuid.clone(),
stark_receipt.clone(),
MAX_REQUEST_RETRY,
)
.await
.map_err(|err| format!("Failed to convert STARK to SNARK: {err:?}"))?;

info!("Validating SNARK uuid: {snark_uuid}");

Expand Down Expand Up @@ -382,8 +398,10 @@ pub fn load_receipt<T: serde::de::DeserializeOwned>(

pub fn save_receipt<T: serde::Serialize>(receipt_label: &String, receipt_data: &(String, T)) {
if !is_dev_mode() {
let cache_path = zkp_cache_path(receipt_label);
info!("Saving receipt to cache: {cache_path:?}");
fs::write(
zkp_cache_path(receipt_label),
cache_path,
bincode::serialize(receipt_data).expect("Failed to serialize receipt!"),
)
.expect("Failed to save receipt output file.");
Expand Down
2 changes: 1 addition & 1 deletion provers/risc0/driver/src/methods/risc0_aggregation.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub const RISC0_AGGREGATION_ELF: &[u8] =
include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-aggregation");
pub const RISC0_AGGREGATION_ID: [u32; 8] = [
3593026424, 359928015, 3488866833, 2676323972, 1129344711, 55769507, 233041442, 3293280986,
3190692238, 1991537256, 2457220677, 1764592515, 1585399420, 97928005, 276688816, 447831862,
];
2 changes: 1 addition & 1 deletion provers/risc0/driver/src/methods/risc0_guest.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub const RISC0_GUEST_ELF: &[u8] =
include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/risc0-guest");
pub const RISC0_GUEST_ID: [u32; 8] = [
2522428380, 1790994278, 397707036, 244564411, 3780865207, 1282154214, 1673205005, 3172292887,
3473581204, 2561439051, 2320161003, 3018340632, 1481329104, 1608433297, 3314099706, 2669934765,
];
47 changes: 34 additions & 13 deletions provers/risc0/driver/src/snarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::{str::FromStr, sync::Arc};
use alloy_primitives::B256;
use alloy_sol_types::{sol, SolValue};
use anyhow::Result;
use bonsai_sdk::alpha::responses::SnarkReceipt;
use bonsai_sdk::blocking::Client;
use ethers_contract::abigen;
use ethers_core::types::H160;
use ethers_providers::{Http, Provider, RetryClient};
Expand All @@ -27,6 +27,7 @@ use risc0_zkvm::{
sha::{Digest, Digestible},
Groth16ReceiptVerifierParameters, Receipt,
};
use tokio::time::{sleep as tokio_async_sleep, Duration};

use tracing::{error as tracing_err, info as tracing_info};

Expand Down Expand Up @@ -86,7 +87,8 @@ pub async fn stark2snark(
image_id: Digest,
stark_uuid: String,
stark_receipt: Receipt,
) -> Result<(String, SnarkReceipt)> {
max_retries: usize,
) -> Result<(String, Receipt)> {
info!("Submitting SNARK workload");
// Label snark output as journal digest
let receipt_label = format!(
Expand All @@ -106,20 +108,38 @@ pub async fn stark2snark(
stark_uuid
};

let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let snark_uuid = client.create_snark(stark_uuid)?;
let client = Client::from_env(risc0_zkvm::VERSION)?;
let snark_uuid = client.create_snark(stark_uuid.clone())?;

let mut retry = 0;
let snark_receipt = loop {
let res = snark_uuid.status(&client)?;

if res.status == "RUNNING" {
info!("Current status: {} - continue polling...", res.status);
std::thread::sleep(std::time::Duration::from_secs(15));
info!(
"Current {:?} status: {} - continue polling...",
&stark_uuid, res.status
);
tokio_async_sleep(Duration::from_secs(15)).await;
} else if res.status == "SUCCEEDED" {
break res
let download_url = res
.output
.expect("Bonsai response is missing SnarkReceipt.");
let receipt_buf = client.download(&download_url)?;
let snark_receipt: Receipt = bincode::deserialize(&receipt_buf)?;
break snark_receipt;
} else {
if retry < max_retries {
retry += 1;
info!(
"Workflow {:?} exited: {} - | err: {} - retrying {}/{max_retries}",
stark_uuid,
res.status,
res.error_msg.unwrap_or_default(),
retry
);
tokio_async_sleep(Duration::from_secs(15)).await;
continue;
}
panic!(
"Workflow exited: {} - | err: {}",
res.status,
Expand All @@ -129,15 +149,15 @@ pub async fn stark2snark(
};

let stark_psd = stark_receipt.claim()?.as_value().unwrap().post.digest();
let snark_psd = Digest::try_from(snark_receipt.post_state_digest.as_slice())?;
let snark_psd = snark_receipt.claim()?.as_value().unwrap().post.digest();

if stark_psd != snark_psd {
error!("SNARK/STARK Post State Digest mismatch!");
error!("STARK: {}", hex::encode(stark_psd));
error!("SNARK: {}", hex::encode(snark_psd));
}

if snark_receipt.journal != stark_receipt.journal.bytes {
if snark_receipt.journal.bytes != stark_receipt.journal.bytes {
error!("SNARK/STARK Receipt Journal mismatch!");
error!("STARK: {}", hex::encode(&stark_receipt.journal.bytes));
error!("SNARK: {}", hex::encode(&snark_receipt.journal));
Expand All @@ -152,11 +172,12 @@ pub async fn stark2snark(

pub async fn verify_groth16_from_snark_receipt(
image_id: Digest,
snark_receipt: SnarkReceipt,
snark_receipt: Receipt,
) -> Result<Vec<u8>> {
let seal = encode(snark_receipt.snark.to_vec())?;
let groth16_claim = snark_receipt.inner.groth16().unwrap();
let seal = groth16_claim.seal.clone();
let journal_digest = snark_receipt.journal.digest();
let post_state_digest = snark_receipt.post_state_digest.digest();
let post_state_digest = snark_receipt.claim()?.as_value().unwrap().post.digest();
let encoded_proof =
verify_groth16_snark_impl(image_id, seal, journal_digest, post_state_digest).await?;
let proof = (encoded_proof, B256::from_slice(image_id.as_bytes()))
Expand Down
Loading

0 comments on commit 7f64cbe

Please sign in to comment.