From 34c1348cb3f001638488c74c5fded0b2a38c101e Mon Sep 17 00:00:00 2001 From: smtmfft <99081233+smtmfft@users.noreply.github.com> Date: Mon, 19 Aug 2024 17:56:00 +0800 Subject: [PATCH] feat(raiko): refine auto-scaling (#346) * 1. use feature to enable auto-scaling 2. give enough time to heat up bonsai 3. use ref count to avoid shutdown with ongoing tasks. * fix compile --- provers/risc0/driver/Cargo.toml | 2 +- provers/risc0/driver/src/bonsai.rs | 6 ++ .../risc0/driver/src/bonsai/auto_scaling.rs | 57 +++++++++++++------ provers/risc0/driver/src/lib.rs | 22 +++---- 4 files changed, 55 insertions(+), 32 deletions(-) diff --git a/provers/risc0/driver/Cargo.toml b/provers/risc0/driver/Cargo.toml index 91abe106c..fe696f470 100644 --- a/provers/risc0/driver/Cargo.toml +++ b/provers/risc0/driver/Cargo.toml @@ -67,4 +67,4 @@ enable = [ cuda = ["risc0-zkvm?/cuda"] metal = ["risc0-zkvm?/metal"] bench = [] - +bonsai-auto-scaling = [] \ No newline at end of file diff --git a/provers/risc0/driver/src/bonsai.rs b/provers/risc0/driver/src/bonsai.rs index 5f57bfc92..4e209df6e 100644 --- a/provers/risc0/driver/src/bonsai.rs +++ b/provers/risc0/driver/src/bonsai.rs @@ -16,6 +16,7 @@ use std::{ use crate::Risc0Param; +#[cfg(feature = "bonsai-auto-scaling")] pub mod auto_scaling; pub async fn verify_bonsai_receipt( @@ -118,6 +119,10 @@ pub async fn maybe_prove anyhow::Result<()> { let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?; let session = bonsai_sdk::alpha::SessionId { uuid }; session.stop(&client)?; + #[cfg(feature = "bonsai-auto-scaling")] auto_scaling::shutdown_bonsai().await?; Ok(()) } diff --git a/provers/risc0/driver/src/bonsai/auto_scaling.rs b/provers/risc0/driver/src/bonsai/auto_scaling.rs index e67808813..c3753e3a2 100644 --- a/provers/risc0/driver/src/bonsai/auto_scaling.rs +++ b/provers/risc0/driver/src/bonsai/auto_scaling.rs @@ -1,8 +1,14 @@ +#![cfg(feature = "bonsai-auto-scaling")] + use anyhow::{Error, Ok, Result}; use lazy_static::lazy_static; +use log::info; +use once_cell::sync::Lazy; use reqwest::{header::HeaderMap, header::HeaderValue, header::CONTENT_TYPE, Client}; use serde::Deserialize; use std::env; +use std::sync::Arc; +use tokio::sync::Mutex; use tracing::{debug, error as trace_err}; #[derive(Debug, Deserialize, Default)] @@ -118,9 +124,20 @@ lazy_static! { .unwrap(); } +static AUTO_SCALER: Lazy>> = Lazy::new(|| { + Arc::new(Mutex::new(BonsaiAutoScaler::new( + BONSAI_API_URL.to_string(), + BONSAI_API_KEY.to_string(), + ))) +}); + +static REF_COUNT: Lazy>> = Lazy::new(|| Arc::new(Mutex::new(0))); + pub(crate) async fn maxpower_bonsai() -> Result<()> { - let mut auto_scaler = - BonsaiAutoScaler::new(BONSAI_API_URL.to_string(), BONSAI_API_KEY.to_string()); + let mut ref_count = REF_COUNT.lock().await; + *ref_count += 1; + + let mut auto_scaler = AUTO_SCALER.lock().await; let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await?; // either already maxed out or pending to be maxed out if current_gpu_num.current == *MAX_BONSAI_GPU_NUM @@ -129,22 +146,31 @@ pub(crate) async fn maxpower_bonsai() -> Result<()> { { Ok(()) } else { + info!("setting bonsai gpu num to: {:?}", *MAX_BONSAI_GPU_NUM); auto_scaler.set_bonsai_gpu_num(*MAX_BONSAI_GPU_NUM).await?; - auto_scaler.wait_for_bonsai_config_active(300).await + auto_scaler.wait_for_bonsai_config_active(900).await } } pub(crate) async fn shutdown_bonsai() -> Result<()> { - let mut auto_scaler = - BonsaiAutoScaler::new(BONSAI_API_URL.to_string(), BONSAI_API_KEY.to_string()); - let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await?; - if current_gpu_num.current == 0 && current_gpu_num.pending == 0 && current_gpu_num.desired == 0 - { - Ok(()) + let mut ref_count = REF_COUNT.lock().await; + *ref_count = ref_count.saturating_sub(1); + + if *ref_count == 0 { + let mut auto_scaler = AUTO_SCALER.lock().await; + let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await?; + if current_gpu_num.current == 0 + && current_gpu_num.desired == 0 + && current_gpu_num.pending == 0 + { + Ok(()) + } else { + info!("setting bonsai gpu num to: 0"); + auto_scaler.set_bonsai_gpu_num(0).await?; + auto_scaler.wait_for_bonsai_config_active(90).await + } } else { - auto_scaler.set_bonsai_gpu_num(0).await?; - // wait few minute for the bonsai to cool down - auto_scaler.wait_for_bonsai_config_active(30).await + Ok(()) } } @@ -184,7 +210,7 @@ mod test { .await .expect("Failed to set bonsai gpu num"); auto_scaler - .wait_for_bonsai_config_active(300) + .wait_for_bonsai_config_active(600) .await .unwrap(); let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await.unwrap().current; @@ -194,10 +220,7 @@ mod test { .set_bonsai_gpu_num(0) .await .expect("Failed to set bonsai gpu num"); - auto_scaler - .wait_for_bonsai_config_active(300) - .await - .unwrap(); + auto_scaler.wait_for_bonsai_config_active(60).await.unwrap(); let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await.unwrap().current; assert_eq!(current_gpu_num, 0); } diff --git a/provers/risc0/driver/src/lib.rs b/provers/risc0/driver/src/lib.rs index 6e9920614..32e1c9b56 100644 --- a/provers/risc0/driver/src/lib.rs +++ b/provers/risc0/driver/src/lib.rs @@ -1,5 +1,11 @@ #![cfg(feature = "enable")] +#[cfg(feature = "bonsai-auto-scaling")] +use crate::bonsai::auto_scaling::shutdown_bonsai; +use crate::{ + methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID}, + snarks::verify_groth16_snark, +}; use alloy_primitives::B256; use hex::ToHex; use log::warn; @@ -13,12 +19,6 @@ use serde_with::serde_as; use std::fmt::Debug; use tracing::{debug, info as traicing_info}; -use crate::{ - bonsai::auto_scaling::{maxpower_bonsai, shutdown_bonsai}, - methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID}, - snarks::verify_groth16_snark, -}; - pub use bonsai::*; pub mod bonsai; @@ -71,13 +71,6 @@ impl Prover for Risc0Prover { debug!("elf code length: {}", RISC0_GUEST_ELF.len()); let encoded_input = to_vec(&input).expect("Could not serialize proving input!"); - if config.bonsai { - // make max speed bonsai - maxpower_bonsai() - .await - .expect("Failed to set max power on Bonsai"); - } - let result = maybe_prove::( &config, encoded_input, @@ -116,8 +109,9 @@ impl Prover for Risc0Prover { journal }; + #[cfg(feature = "bonsai-auto-scaling")] if config.bonsai { - // shutdown max speed bonsai + // shutdown bonsai shutdown_bonsai() .await .map_err(|e| ProverError::GuestError(e.to_string()))?;