From f1af18d20f1f76d18eeff6fe7df1931a0e1b6728 Mon Sep 17 00:00:00 2001 From: Petar Vujovic Date: Wed, 17 Jul 2024 10:28:05 +0200 Subject: [PATCH] fix(core,host,lib): handle async boundary --- core/src/interfaces.rs | 4 ++-- core/src/lib.rs | 8 ++++++-- core/src/prover.rs | 4 ++-- host/src/proof.rs | 14 ++++++++------ host/src/server/api/v1/proof.rs | 1 + lib/src/prover.rs | 6 +++--- 6 files changed, 22 insertions(+), 15 deletions(-) diff --git a/core/src/interfaces.rs b/core/src/interfaces.rs index f75fb98bd..969ba7848 100644 --- a/core/src/interfaces.rs +++ b/core/src/interfaces.rs @@ -160,7 +160,7 @@ impl ProofType { input: GuestInput, output: &GuestOutput, config: &Value, - store: &mut dyn IdWrite, + store: Option<&mut dyn IdWrite>, ) -> RaikoResult { let mut proof = match self { ProofType::Native => NativeProver::run(input.clone(), output, config, store) @@ -212,7 +212,7 @@ impl ProofType { pub async fn cancel_proof( &self, proof_key: ProofKey, - read: &mut dyn IdStore, + read: Box<&mut dyn IdStore>, ) -> RaikoResult<()> { let _ = match self { ProofType::Native => NativeProver::cancel(proof_key.clone(), read) diff --git a/core/src/lib.rs b/core/src/lib.rs index 3705b1051..fd6b0666b 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -105,7 +105,7 @@ impl Raiko { &self, input: GuestInput, output: &GuestOutput, - store: &mut dyn IdWrite, + store: Option<&mut dyn IdWrite>, ) -> RaikoResult { let config = serde_json::to_value(&self.request)?; self.request @@ -114,7 +114,11 @@ impl Raiko { .await } - pub async fn cancel(&self, proof_key: ProofKey, read: &mut dyn IdStore) -> RaikoResult<()> { + pub async fn cancel( + &self, + proof_key: ProofKey, + read: Box<&mut dyn IdStore>, + ) -> RaikoResult<()> { self.request.proof_type.cancel_proof(proof_key, read).await } } diff --git a/core/src/prover.rs b/core/src/prover.rs index ac8ad6326..4ac1a512d 100644 --- a/core/src/prover.rs +++ b/core/src/prover.rs @@ -28,7 +28,7 @@ impl Prover for NativeProver { input: GuestInput, output: &GuestOutput, config: &ProverConfig, - _store: &mut dyn IdWrite, + _store: Option<&mut dyn IdWrite>, ) -> ProverResult { let param = config @@ -64,7 +64,7 @@ impl Prover for NativeProver { }) } - async fn cancel(_proof_key: ProofKey, _read: &mut dyn IdStore) -> ProverResult<()> { + async fn cancel(_proof_key: ProofKey, _read: Box<&mut dyn IdStore>) -> ProverResult<()> { Ok(()) } } diff --git a/host/src/proof.rs b/host/src/proof.rs index 525bfe543..e51d3f6f1 100644 --- a/host/src/proof.rs +++ b/host/src/proof.rs @@ -10,7 +10,7 @@ use raiko_lib::{ prover::{IdWrite, Proof}, Measurement, }; -use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManager, TaskStatus}; +use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManager, TaskManagerWrapper, TaskStatus}; use tokio::{ select, sync::{mpsc::Receiver, Mutex, OwnedSemaphorePermit, Semaphore}, @@ -59,7 +59,7 @@ impl ProofActor { let mut manager = get_task_manager(&self.opts.clone().into()); key.proof_system - .cancel_proof((key.chain_id, key.blockhash), &mut manager) + .cancel_proof((key.chain_id, key.blockhash), Box::new(&mut manager)) .await?; task.cancel(); Ok(()) @@ -120,7 +120,9 @@ impl ProofActor { while let Some(message) = self.receiver.recv().await { match message { Message::Cancel(key) => { - self.cancel_task(key).await; + if let Err(error) = self.cancel_task(key).await { + error!("Failed to cancel task: {error}") + } } Message::Task(proof_request) => { let permit = Arc::clone(&semaphore) @@ -154,7 +156,7 @@ impl ProofActor { .await?; let (status, proof) = - match handle_proof(&proof_request, opts, chain_specs, &mut manager).await { + match handle_proof(&proof_request, opts, chain_specs, Some(&mut manager)).await { Err(error) => { error!("{error}"); (error.into(), None) @@ -173,7 +175,7 @@ pub async fn handle_proof( proof_request: &ProofRequest, opts: &Opts, chain_specs: &SupportedChainSpecs, - store: &mut dyn IdWrite, + store: Option<&mut TaskManagerWrapper>, ) -> HostResult { info!( "# Generating proof for block {} on {}", @@ -227,7 +229,7 @@ pub async fn handle_proof( memory::reset_stats(); let measurement = Measurement::start("Generating proof...", false); let proof = raiko - .prove(input.clone(), &output, store) + .prove(input.clone(), &output, store.map(|s| s as &mut dyn IdWrite)) .await .map_err(|e| { let total_time = total_time.stop_with("====> Proof generation failed"); diff --git a/host/src/server/api/v1/proof.rs b/host/src/server/api/v1/proof.rs index 676d98a31..dabdcbc0a 100644 --- a/host/src/server/api/v1/proof.rs +++ b/host/src/server/api/v1/proof.rs @@ -46,6 +46,7 @@ async fn proof_handler( &proof_request, &prover_state.opts, &prover_state.chain_specs, + None, ) .await .map_err(|e| { diff --git a/lib/src/prover.rs b/lib/src/prover.rs index 03e234da1..33ef12b80 100644 --- a/lib/src/prover.rs +++ b/lib/src/prover.rs @@ -37,7 +37,7 @@ pub struct Proof { pub kzg_proof: Option, } -pub trait IdWrite { +pub trait IdWrite: Send { fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()>; fn remove_id(&mut self, key: ProofKey) -> ProverResult<()>; @@ -53,8 +53,8 @@ pub trait Prover { input: GuestInput, output: &GuestOutput, config: &ProverConfig, - store: &mut dyn IdWrite, + store: Option<&mut dyn IdWrite>, ) -> ProverResult; - async fn cancel(proof_key: ProofKey, read: &mut dyn IdStore) -> ProverResult<()>; + async fn cancel(proof_key: ProofKey, read: Box<&mut dyn IdStore>) -> ProverResult<()>; }