Skip to content

Commit

Permalink
fix(sdk): better error handling (#1643)
Browse files Browse the repository at this point in the history
  • Loading branch information
ctian1 authored Oct 16, 2024
1 parent d407bc3 commit 3d0c989
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 20 deletions.
28 changes: 23 additions & 5 deletions book/generating-proofs/prover-network/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ To skip the simulation step and directly submit the program for proof generation

### Use NetworkProver directly

By using the `sp1_sdk::NetworkProver` struct directly, you can call async functions directly and have programmatic access to the proof ID.
By using the `sp1_sdk::NetworkProver` struct directly, you can call async functions directly, get programmatic access to the proof ID, and download proofs by ID.

```rust,noplayground
impl NetworkProver {
/// Creates a new [NetworkProver] with the private key set in `SP1_PRIVATE_KEY`.
pub fn new() -> Self;
/// Creates a new [NetworkProver] with the given private key.
pub fn new_from_key(private_key: &str) -> Self;
pub fn new_from_key(private_key: &str) -> Self;
/// Requests a proof from the prover network, returning the proof ID.
pub async fn request_proof(
Expand All @@ -56,10 +56,28 @@ impl NetworkProver {
mode: ProofMode,
) -> Result<String>;
/// Waits for a proof to be generated and returns the proof.
pub async fn wait_proof<P: DeserializeOwned>(&self, proof_id: &str) -> Result<P>;
/// Waits for a proof to be generated and returns the proof. If a timeout is supplied, the
/// function will return an error if the proof is not generated within the timeout.
pub async fn wait_proof(
&self,
proof_id: &str,
timeout: Option<Duration>,
) -> Result<SP1ProofWithPublicValues>;
/// Gets the status of a given proof request and, if available, the proof. The proof is
/// returned only if the status is Fulfilled.
pub async fn get_proof_status(
&self,
proof_id: &str,
) -> Result<(GetProofStatusResponse, Option<SP1ProofWithPublicValues>)>;
/// Requests a proof from the prover network and waits for it to be generated.
pub async fn prove<P: ProofType>(&self, elf: &[u8], stdin: SP1Stdin) -> Result<P>;
pub async fn prove(
&self,
elf: &[u8],
stdin: SP1Stdin,
mode: ProofMode,
timeout: Option<Duration>,
) -> Result<SP1ProofWithPublicValues>;
}
```
28 changes: 18 additions & 10 deletions crates/sdk/src/network/client.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use std::{env, time::Duration};

use crate::{
network::auth::NetworkAuth,
network::proto::network::{
ModifyCpuCyclesRequest, ModifyCpuCyclesResponse, UnclaimProofRequest, UnclaimReason,
network::{
auth::NetworkAuth,
proto::network::{
ModifyCpuCyclesRequest, ModifyCpuCyclesResponse, UnclaimProofRequest, UnclaimReason,
},
},
SP1ProofWithPublicValues,
};
use anyhow::{Context, Ok, Result};
use futures::{future::join_all, Future};
use reqwest::{Client as HttpClient, Url};
use reqwest_middleware::ClientWithMiddleware as HttpClientWithMiddleware;
use serde::de::DeserializeOwned;
use sp1_core_machine::io::SP1Stdin;
use std::{
result::Result::Ok as StdOk,
Expand All @@ -29,7 +31,10 @@ use crate::network::proto::network::{
pub const DEFAULT_PROVER_NETWORK_RPC: &str = "https://rpc.succinct.xyz/";

/// The timeout for a proof request to be fulfilled.
const TIMEOUT: Duration = Duration::from_secs(60 * 60);
const PROOF_TIMEOUT: Duration = Duration::from_secs(60 * 60);

/// The timeout for a single RPC request.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

pub struct NetworkClient {
pub rpc: TwirpClient,
Expand All @@ -48,6 +53,7 @@ impl NetworkClient {
let auth = NetworkAuth::new(private_key);

let twirp_http_client = HttpClient::builder()
.timeout(REQUEST_TIMEOUT)
.pool_max_idle_per_host(0)
.pool_idle_timeout(Duration::from_secs(240))
.build()
Expand All @@ -58,6 +64,7 @@ impl NetworkClient {
TwirpClient::new(Url::parse(&rpc_url).unwrap(), twirp_http_client, vec![]).unwrap();

let http_client = HttpClient::builder()
.timeout(REQUEST_TIMEOUT)
.pool_max_idle_per_host(0)
.pool_idle_timeout(Duration::from_secs(240))
.build()
Expand All @@ -82,12 +89,12 @@ impl NetworkClient {
Ok(())
}

/// Get the status of a given proof. If the status is ProofFulfilled, the proof is also
/// returned.
pub async fn get_proof_status<P: DeserializeOwned>(
/// Gets the status of a given proof request and, if available, the proof. The proof is
/// returned only if the status is Fulfilled.
pub async fn get_proof_status(
&self,
proof_id: &str,
) -> Result<(GetProofStatusResponse, Option<P>)> {
) -> Result<(GetProofStatusResponse, Option<SP1ProofWithPublicValues>)> {
let res = self
.with_error_handling(
self.rpc.get_proof_status(GetProofStatusRequest { proof_id: proof_id.to_string() }),
Expand All @@ -101,6 +108,7 @@ impl NetworkClient {
let proof_bytes = self
.http
.get(res.proof_url.as_ref().expect("no proof url"))
.timeout(Duration::from_secs(120))
.send()
.await
.context("Failed to send HTTP request for proof")?
Expand Down Expand Up @@ -139,7 +147,7 @@ impl NetworkClient {
) -> Result<String> {
let start = SystemTime::now();
let since_the_epoch = start.duration_since(UNIX_EPOCH).expect("Invalid start time");
let deadline = since_the_epoch.as_secs() + TIMEOUT.as_secs();
let deadline = since_the_epoch.as_secs() + PROOF_TIMEOUT.as_secs();

let nonce = self.get_nonce().await?;
let create_proof_signature = self
Expand Down
43 changes: 38 additions & 5 deletions crates/sdk/src/network/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@ use crate::{
Prover, SP1Context, SP1ProofKind, SP1ProofWithPublicValues, SP1ProvingKey, SP1VerifyingKey,
};
use anyhow::Result;
use serde::de::DeserializeOwned;
use sp1_core_machine::io::SP1Stdin;
use sp1_prover::{components::DefaultProverComponents, SP1Prover, SP1_CIRCUIT_VERSION};
use sp1_stark::SP1ProverOpts;

use super::proto::network::GetProofStatusResponse;

use {crate::block_on, tokio::time::sleep};

use crate::provers::{CpuProver, ProofOpts, ProverType};

/// Number of consecutive errors to tolerate before returning an error while polling proof status.
const MAX_CONSECUTIVE_ERRORS: usize = 10;

/// An implementation of [crate::ProverClient] that can generate proofs on a remote RPC server.
pub struct NetworkProver {
client: NetworkClient,
Expand Down Expand Up @@ -71,22 +75,43 @@ impl NetworkProver {

/// Waits for a proof to be generated and returns the proof. If a timeout is supplied, the
/// function will return an error if the proof is not generated within the timeout.
pub async fn wait_proof<P: DeserializeOwned>(
pub async fn wait_proof(
&self,
proof_id: &str,
timeout: Option<Duration>,
) -> Result<P> {
) -> Result<SP1ProofWithPublicValues> {
let client = &self.client;
let mut is_claimed = false;
let start_time = Instant::now();
let mut consecutive_errors = 0;
loop {
if let Some(timeout) = timeout {
if start_time.elapsed() > timeout {
return Err(anyhow::anyhow!("Proof generation timed out."));
}
}

let (status, maybe_proof) = client.get_proof_status::<P>(proof_id).await?;
let result = client.get_proof_status(proof_id).await;

if let Err(e) = result {
consecutive_errors += 1;
log::warn!(
"Failed to get proof status ({}/{}): {:?}",
consecutive_errors,
MAX_CONSECUTIVE_ERRORS,
e
);
if consecutive_errors == MAX_CONSECUTIVE_ERRORS {
return Err(anyhow::anyhow!(
"Proof generation failed: {} consecutive errors.",
MAX_CONSECUTIVE_ERRORS
));
}
continue;
}
consecutive_errors = 0;

let (status, maybe_proof) = result.unwrap();

match status.status() {
ProofStatus::ProofFulfilled => {
Expand All @@ -110,6 +135,15 @@ impl NetworkProver {
}
}

/// Looks up the proof request identified by `proof_id` by delegating to the underlying
/// network client.
///
/// Returns the status response paired with the proof when one is available; the proof is
/// returned only if the status is Fulfilled. Any error produced by the client call is passed
/// through unchanged.
pub async fn get_proof_status(
    &self,
    proof_id: &str,
) -> Result<(GetProofStatusResponse, Option<SP1ProofWithPublicValues>)> {
    let network_client = &self.client;
    network_client.get_proof_status(proof_id).await
}

/// Requests a proof from the prover network and waits for it to be generated.
pub async fn prove(
&self,
Expand Down Expand Up @@ -157,7 +191,6 @@ impl Default for NetworkProver {

/// Warns if `opts` or `context` are not default values, since they are currently unsupported.
fn warn_if_not_default(opts: &SP1ProverOpts, context: &SP1Context) {
let _guard = tracing::warn_span!("network_prover").entered();
if opts != &SP1ProverOpts::default() {
tracing::warn!("non-default opts will be ignored: {:?}", opts.core_opts);
tracing::warn!("custom SP1ProverOpts are currently unsupported by the network prover");
Expand Down

0 comments on commit 3d0c989

Please sign in to comment.