Skip to content

Commit

Permalink
feat(raiko): refine auto-scaling (#346)
Browse files Browse the repository at this point in the history
* 1. use feature to enable auto-scaling
2. give enough time to heat up bonsai
3. use ref count to avoid shutdown with ongoing tasks.

* fix compile
  • Loading branch information
smtmfft authored Aug 19, 2024
1 parent dc89e60 commit 34c1348
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 32 deletions.
2 changes: 1 addition & 1 deletion provers/risc0/driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ enable = [
cuda = ["risc0-zkvm?/cuda"]
metal = ["risc0-zkvm?/metal"]
bench = []

bonsai-auto-scaling = []
6 changes: 6 additions & 0 deletions provers/risc0/driver/src/bonsai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::{

use crate::Risc0Param;

#[cfg(feature = "bonsai-auto-scaling")]
pub mod auto_scaling;

pub async fn verify_bonsai_receipt<O: Eq + Debug + DeserializeOwned>(
Expand Down Expand Up @@ -118,6 +119,10 @@ pub async fn maybe_prove<I: Serialize, O: Eq + Debug + Serialize + DeserializeOw
info!("Loaded locally cached stark receipt {receipt_label:?}");
(cached_data.0, cached_data.1, true)
} else if param.bonsai {
#[cfg(feature = "bonsai-auto-scaling")]
auto_scaling::maxpower_bonsai()
.await
.expect("Failed to set max power on Bonsai");
// query bonsai service until it works
loop {
match prove_bonsai(
Expand Down Expand Up @@ -196,6 +201,7 @@ pub async fn cancel_proof(uuid: String) -> anyhow::Result<()> {
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let session = bonsai_sdk::alpha::SessionId { uuid };
session.stop(&client)?;
#[cfg(feature = "bonsai-auto-scaling")]
auto_scaling::shutdown_bonsai().await?;
Ok(())
}
Expand Down
57 changes: 40 additions & 17 deletions provers/risc0/driver/src/bonsai/auto_scaling.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
#![cfg(feature = "bonsai-auto-scaling")]

use anyhow::{Error, Ok, Result};
use lazy_static::lazy_static;
use log::info;
use once_cell::sync::Lazy;
use reqwest::{header::HeaderMap, header::HeaderValue, header::CONTENT_TYPE, Client};
use serde::Deserialize;
use std::env;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, error as trace_err};

#[derive(Debug, Deserialize, Default)]
Expand Down Expand Up @@ -118,9 +124,20 @@ lazy_static! {
.unwrap();
}

static AUTO_SCALER: Lazy<Arc<Mutex<BonsaiAutoScaler>>> = Lazy::new(|| {
Arc::new(Mutex::new(BonsaiAutoScaler::new(
BONSAI_API_URL.to_string(),
BONSAI_API_KEY.to_string(),
)))
});

static REF_COUNT: Lazy<Arc<Mutex<u32>>> = Lazy::new(|| Arc::new(Mutex::new(0)));

pub(crate) async fn maxpower_bonsai() -> Result<()> {
let mut auto_scaler =
BonsaiAutoScaler::new(BONSAI_API_URL.to_string(), BONSAI_API_KEY.to_string());
let mut ref_count = REF_COUNT.lock().await;
*ref_count += 1;

let mut auto_scaler = AUTO_SCALER.lock().await;
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await?;
// either already maxed out or pending to be maxed out
if current_gpu_num.current == *MAX_BONSAI_GPU_NUM
Expand All @@ -129,22 +146,31 @@ pub(crate) async fn maxpower_bonsai() -> Result<()> {
{
Ok(())
} else {
info!("setting bonsai gpu num to: {:?}", *MAX_BONSAI_GPU_NUM);
auto_scaler.set_bonsai_gpu_num(*MAX_BONSAI_GPU_NUM).await?;
auto_scaler.wait_for_bonsai_config_active(300).await
auto_scaler.wait_for_bonsai_config_active(900).await
}
}

pub(crate) async fn shutdown_bonsai() -> Result<()> {
let mut auto_scaler =
BonsaiAutoScaler::new(BONSAI_API_URL.to_string(), BONSAI_API_KEY.to_string());
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await?;
if current_gpu_num.current == 0 && current_gpu_num.pending == 0 && current_gpu_num.desired == 0
{
Ok(())
let mut ref_count = REF_COUNT.lock().await;
*ref_count = ref_count.saturating_sub(1);

if *ref_count == 0 {
let mut auto_scaler = AUTO_SCALER.lock().await;
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await?;
if current_gpu_num.current == 0
&& current_gpu_num.desired == 0
&& current_gpu_num.pending == 0
{
Ok(())
} else {
info!("setting bonsai gpu num to: 0");
auto_scaler.set_bonsai_gpu_num(0).await?;
auto_scaler.wait_for_bonsai_config_active(90).await
}
} else {
auto_scaler.set_bonsai_gpu_num(0).await?;
// wait few minute for the bonsai to cool down
auto_scaler.wait_for_bonsai_config_active(30).await
Ok(())
}
}

Expand Down Expand Up @@ -184,7 +210,7 @@ mod test {
.await
.expect("Failed to set bonsai gpu num");
auto_scaler
.wait_for_bonsai_config_active(300)
.wait_for_bonsai_config_active(600)
.await
.unwrap();
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await.unwrap().current;
Expand All @@ -194,10 +220,7 @@ mod test {
.set_bonsai_gpu_num(0)
.await
.expect("Failed to set bonsai gpu num");
auto_scaler
.wait_for_bonsai_config_active(300)
.await
.unwrap();
auto_scaler.wait_for_bonsai_config_active(60).await.unwrap();
let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await.unwrap().current;
assert_eq!(current_gpu_num, 0);
}
Expand Down
22 changes: 8 additions & 14 deletions provers/risc0/driver/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
#![cfg(feature = "enable")]

#[cfg(feature = "bonsai-auto-scaling")]
use crate::bonsai::auto_scaling::shutdown_bonsai;
use crate::{
methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID},
snarks::verify_groth16_snark,
};
use alloy_primitives::B256;
use hex::ToHex;
use log::warn;
Expand All @@ -13,12 +19,6 @@ use serde_with::serde_as;
use std::fmt::Debug;
use tracing::{debug, info as traicing_info};

use crate::{
bonsai::auto_scaling::{maxpower_bonsai, shutdown_bonsai},
methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID},
snarks::verify_groth16_snark,
};

pub use bonsai::*;

pub mod bonsai;
Expand Down Expand Up @@ -71,13 +71,6 @@ impl Prover for Risc0Prover {
debug!("elf code length: {}", RISC0_GUEST_ELF.len());
let encoded_input = to_vec(&input).expect("Could not serialize proving input!");

if config.bonsai {
// make max speed bonsai
maxpower_bonsai()
.await
.expect("Failed to set max power on Bonsai");
}

let result = maybe_prove::<GuestInput, B256>(
&config,
encoded_input,
Expand Down Expand Up @@ -116,8 +109,9 @@ impl Prover for Risc0Prover {
journal
};

#[cfg(feature = "bonsai-auto-scaling")]
if config.bonsai {
// shutdown max speed bonsai
// shutdown bonsai
shutdown_bonsai()
.await
.map_err(|e| ProverError::GuestError(e.to_string()))?;
Expand Down

0 comments on commit 34c1348

Please sign in to comment.