diff --git a/.env.example b/.env.example
index dc06a2d..b50fa09 100644
--- a/.env.example
+++ b/.env.example
@@ -9,6 +9,7 @@ DKN_ADMIN_PUBLIC_KEY=0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110b
 
 # example: phi3:3.8b,gpt-4o-mini
 DKN_MODELS=
+
 ## DRIA (optional) ##
 # P2P address, you don't need to change this unless this port is already in use.
 DKN_P2P_LISTEN_ADDR=/ip4/0.0.0.0/tcp/4001
@@ -16,6 +17,8 @@ DKN_P2P_LISTEN_ADDR=/ip4/0.0.0.0/tcp/4001
 DKN_RELAY_NODES=
 # Comma-separated static bootstrap nodes
 DKN_BOOTSTRAP_NODES=
+# Batch size for workflows, you do not need to edit this.
+DKN_BATCH_SIZE=
 
 ## DRIA (profiling only, do not uncomment) ##
 # Set to a number of seconds to wait before exiting, only use in profiling build!
diff --git a/Cargo.lock b/Cargo.lock
index 88c22c0..d50acb8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -974,7 +974,7 @@ dependencies = [
 
 [[package]]
 name = "dkn-compute"
-version = "0.2.26"
+version = "0.2.27"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -1007,7 +1007,7 @@ dependencies = [
 
 [[package]]
 name = "dkn-monitor"
-version = "0.2.26"
+version = "0.2.27"
 dependencies = [
  "async-trait",
  "dkn-compute",
@@ -1027,7 +1027,7 @@ dependencies = [
 
 [[package]]
 name = "dkn-p2p"
-version = "0.2.26"
+version = "0.2.27"
 dependencies = [
  "dkn-utils",
  "env_logger 0.11.5",
@@ -1041,11 +1041,11 @@ dependencies = [
 
 [[package]]
 name = "dkn-utils"
-version = "0.2.26"
+version = "0.2.27"
 
 [[package]]
 name = "dkn-workflows"
-version = "0.2.26"
+version = "0.2.27"
 dependencies = [
  "dkn-utils",
  "dotenvy",
diff --git a/Cargo.toml b/Cargo.toml
index 033ad29..406743a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,7 +8,7 @@ default-members = ["compute"]
 
 [workspace.package]
 edition = "2021"
-version = "0.2.26"
+version = "0.2.27"
 license = "Apache-2.0"
 readme = "README.md"
diff --git a/compute/src/config.rs b/compute/src/config.rs
index 85db998..23fa8c0 100644
--- a/compute/src/config.rs
+++ b/compute/src/config.rs
@@ -9,6 +9,9 @@ use libsecp256k1::{PublicKey, SecretKey};
 use std::{env, str::FromStr};
 
+// TODO: make this configurable later
+const DEFAULT_WORKFLOW_BATCH_SIZE: usize = 5;
+
 #[derive(Debug, Clone)]
 pub struct DriaComputeNodeConfig {
     /// Wallet secret/private key.
@@ -25,6 +28,11 @@ pub struct DriaComputeNodeConfig {
     pub workflows: DriaWorkflowsConfig,
     /// Network type of the node.
     pub network_type: DriaNetworkType,
+    /// Batch size for batchable workflows.
+    ///
+    /// A higher value will help execute more tasks concurrently,
+    /// at the risk of hitting rate-limits.
+    pub batch_size: usize,
 }
 
 /// The default P2P network listen address.
@@ -103,6 +111,11 @@ impl DriaComputeNodeConfig {
             .map(|s| DriaNetworkType::from(s.as_str()))
             .unwrap_or_default();
 
+        // parse batch size
+        let batch_size = env::var("DKN_BATCH_SIZE")
+            .map(|s| s.parse::<usize>().unwrap_or(DEFAULT_WORKFLOW_BATCH_SIZE))
+            .unwrap_or(DEFAULT_WORKFLOW_BATCH_SIZE);
+
         Self {
             admin_public_key,
             secret_key,
@@ -111,6 +124,7 @@ impl DriaComputeNodeConfig {
             workflows,
             p2p_listen_addr,
             network_type,
+            batch_size,
         }
     }
 
diff --git a/compute/src/handlers/workflow.rs b/compute/src/handlers/workflow.rs
index 340d035..ae651e9 100644
--- a/compute/src/handlers/workflow.rs
+++ b/compute/src/handlers/workflow.rs
@@ -129,11 +129,7 @@
                 // convert payload to message
                 let payload_str = serde_json::json!(payload).to_string();
-                log::debug!(
-                    "Publishing result for task {}\n{}",
-                    task.task_id,
-                    payload_str
-                );
+                log::info!("Publishing result for task {}", task.task_id);
 
                 DriaMessage::new(payload_str, Self::RESPONSE_TOPIC)
             }
             Err(err) => {
@@ -161,7 +157,7 @@
 
         // try publishing the result
         if let Err(publish_err) = node.publish(message).await {
-            let err_msg = format!("could not publish result: {:?}", publish_err);
+            let err_msg = format!("Could not publish task result: {:?}", publish_err);
             log::error!("{}", err_msg);
 
             let payload = serde_json::json!({
diff --git a/compute/src/main.rs b/compute/src/main.rs
index c90483b..f0e1db7 100644
--- a/compute/src/main.rs
+++ b/compute/src/main.rs
@@ -3,6 +3,7 @@ use dkn_workflows::DriaWorkflowsConfig;
 use eyre::Result;
 use std::env;
 use tokio_util::{sync::CancellationToken, task::TaskTracker};
+use workers::workflow::WorkflowsWorker;
 
 #[tokio::main]
 async fn main() -> Result<()> {
@@ -86,6 +87,7 @@ async fn main() -> Result<()> {
     log::warn!("Using models: {:#?}", config.workflows.models);
 
     // create the node
+    let batch_size = config.batch_size;
     let (mut node, p2p, worker_batch, worker_single) = DriaComputeNode::new(config).await?;
 
     // spawn p2p client first
@@ -94,14 +96,21 @@ async fn main() -> Result<()> {
 
     // spawn batch worker thread if we are using such models (e.g. OpenAI, Gemini, OpenRouter)
     if let Some(mut worker_batch) = worker_batch {
-        log::info!("Spawning workflows batch worker thread.");
-        task_tracker.spawn(async move { worker_batch.run_batch().await });
+        assert!(
+            batch_size <= WorkflowsWorker::MAX_BATCH_SIZE,
+            "batch size too large"
+        );
+        log::info!(
+            "Spawning workflows batch worker thread. (batch size {})",
+            batch_size
+        );
+        task_tracker.spawn(async move { worker_batch.run_batch(batch_size).await });
     }
 
     // spawn single worker thread if we are using such models (e.g. Ollama)
     if let Some(mut worker_single) = worker_single {
         log::info!("Spawning workflows single worker thread.");
-        task_tracker.spawn(async move { worker_single.run().await });
+        task_tracker.spawn(async move { worker_single.run_series().await });
     }
 
     // spawn compute node thread
diff --git a/compute/src/node.rs b/compute/src/node.rs
index bd41fe0..985d7e0 100644
--- a/compute/src/node.rs
+++ b/compute/src/node.rs
@@ -81,9 +81,7 @@ impl DriaComputeNode {
         let (p2p_client, p2p_commander, message_rx) = DriaP2PClient::new(
             keypair,
             config.p2p_listen_addr.clone(),
-            available_nodes.bootstrap_nodes.clone().into_iter(),
-            available_nodes.relay_nodes.clone().into_iter(),
-            available_nodes.rpc_nodes.clone().into_iter(),
+            &available_nodes,
             protocol,
         )?;
 
diff --git a/compute/src/workers/workflow.rs b/compute/src/workers/workflow.rs
index 847332f..0a33c30 100644
--- a/compute/src/workers/workflow.rs
+++ b/compute/src/workers/workflow.rs
@@ -32,7 +32,9 @@ pub struct WorkflowsWorkerOutput {
 ///
 /// It is expected to be spawned in another thread, with `run_batch` for batch processing and `run` for single processing.
 pub struct WorkflowsWorker {
+    /// Workflow message channel receiver, the sender is most likely the compute node itself.
     workflow_rx: mpsc::Receiver<WorkflowsWorkerInput>,
+    /// Publish message channel sender, the receiver is most likely the compute node itself.
     publish_tx: mpsc::Sender<WorkflowsWorkerOutput>,
 }
 
@@ -40,10 +42,11 @@ pub struct WorkflowsWorker {
 const WORKFLOW_CHANNEL_BUFSIZE: usize = 1024;
 
 impl WorkflowsWorker {
-    /// Batch size that defines how many tasks can be executed in parallel at once.
-    /// IMPORTANT NOTE: `run` function is designed to handle the batch size here specifically,
+    /// Batch size that defines how many tasks can be executed concurrently at once.
+    ///
+    /// The `run` function is designed to handle the batch size here specifically,
     /// if there are more tasks than the batch size, the function will panic.
-    const BATCH_SIZE: usize = 8;
+    pub const MAX_BATCH_SIZE: usize = 8;
 
     /// Creates a worker and returns the sender and receiver for the worker.
     pub fn new(
@@ -65,24 +68,20 @@
         self.workflow_rx.close();
     }
 
-    /// Launches the thread that can process tasks one by one.
+    /// Launches the thread that can process tasks one by one (in series).
     /// This function will block until the channel is closed.
     ///
    /// It is suitable for task streams that consume local resources, unlike API calls.
-    pub async fn run(&mut self) {
+    pub async fn run_series(&mut self) {
         loop {
             let task = self.workflow_rx.recv().await;
 
-            let result = if let Some(task) = task {
+            if let Some(task) = task {
                 log::info!("Processing single workflow for task {}", task.task_id);
-                WorkflowsWorker::execute(task).await
+                WorkflowsWorker::execute((task, self.publish_tx.clone())).await
             } else {
                 return self.shutdown();
             };
-
-            if let Err(e) = self.publish_tx.send(result).await {
-                log::error!("Error sending workflow result: {}", e);
-            }
         }
     }
 
@@ -91,13 +90,16 @@
     ///
     /// It is suitable for task streams that make use of API calls, unlike Ollama-like
     /// tasks that consumes local resources and would not make sense to run in parallel.
-    pub async fn run_batch(&mut self) {
+    ///
+    /// Batch size must NOT be larger than `MAX_BATCH_SIZE`, otherwise will panic.
+    pub async fn run_batch(&mut self, batch_size: usize) {
+        // TODO: need some better batch_size error handling here
         loop {
             // get tasks in batch from the channel
             let mut task_buffer = Vec::new();
             let num_tasks = self
                 .workflow_rx
-                .recv_many(&mut task_buffer, Self::BATCH_SIZE)
+                .recv_many(&mut task_buffer, batch_size)
                 .await;
 
             if num_tasks == 0 {
@@ -106,8 +108,10 @@
 
             // process the batch
             log::info!("Processing {} workflows in batch", num_tasks);
-            let mut batch = task_buffer.into_iter();
-            let results = match num_tasks {
+            let mut batch = task_buffer
+                .into_iter()
+                .map(|b| (b, self.publish_tx.clone()));
+            match num_tasks {
                 1 => {
                     let r0 = WorkflowsWorker::execute(batch.next().unwrap()).await;
                     vec![r0]
@@ -186,23 +190,17 @@
                    unreachable!(
                        "number of tasks cant be larger than batch size ({} > {})",
                        num_tasks,
-                        Self::BATCH_SIZE
+                        Self::MAX_BATCH_SIZE
                    );
                }
            };
-
-            // publish all results
-            log::info!("Publishing {} workflow results", results.len());
-            for result in results {
-                if let Err(e) = self.publish_tx.send(result).await {
-                    log::error!("Error sending workflow result: {}", e);
-                }
-            }
         }
     }
 
-    /// A single task execution.
-    pub async fn execute(input: WorkflowsWorkerInput) -> WorkflowsWorkerOutput {
+    /// Executes a single task, and publishes the output.
+    pub async fn execute(
+        (input, publish_tx): (WorkflowsWorkerInput, mpsc::Sender<WorkflowsWorkerOutput>),
+    ) {
         let mut memory = ProgramMemory::new();
 
         let started_at = std::time::Instant::now();
@@ -211,13 +209,17 @@
             .execute(input.entry.as_ref(), &input.workflow, &mut memory)
             .await;
 
-        WorkflowsWorkerOutput {
+        let output = WorkflowsWorkerOutput {
             result,
             public_key: input.public_key,
             task_id: input.task_id,
             model_name: input.model_name,
             batchable: input.batchable,
             stats: input.stats.record_execution_time(started_at),
+        };
+
+        if let Err(e) = publish_tx.send(output).await {
+            log::error!("Error sending workflow result: {}", e);
         }
     }
 }
diff --git a/monitor/src/main.rs b/monitor/src/main.rs
index 5603ddc..bf10c3d 100644
--- a/monitor/src/main.rs
+++ b/monitor/src/main.rs
@@ -33,9 +33,7 @@ async fn main() -> eyre::Result<()> {
     let (client, commander, msg_rx) = DriaP2PClient::new(
         keypair,
         listen_addr,
-        nodes.bootstrap_nodes.into_iter(),
-        nodes.relay_nodes.into_iter(),
-        nodes.rpc_nodes.into_iter(),
+        &nodes,
         DriaP2PProtocol::new_major_minor(network.protocol_name()),
     )?;
 
diff --git a/p2p/src/client.rs b/p2p/src/client.rs
index 7d0c0ee..92ecb88 100644
--- a/p2p/src/client.rs
+++ b/p2p/src/client.rs
@@ -10,7 +10,7 @@ use std::time::Duration;
 use tokio::sync::mpsc;
 
 use crate::behaviour::{DriaBehaviour, DriaBehaviourEvent};
-use crate::DriaP2PProtocol;
+use crate::{DriaNodes, DriaP2PProtocol};
 
 use super::commands::DriaP2PCommand;
 use super::DriaP2PCommander;
@@ -46,9 +46,7 @@ impl DriaP2PClient {
     pub fn new(
         keypair: Keypair,
         listen_addr: Multiaddr,
-        bootstraps: impl Iterator<Item = Multiaddr>,
-        relays: impl Iterator<Item = Multiaddr>,
-        rpcs: impl Iterator<Item = Multiaddr>,
+        nodes: &DriaNodes,
         protocol: DriaP2PProtocol,
     ) -> Result<(
         DriaP2PClient,
@@ -89,7 +87,7 @@
             .set_mode(Some(libp2p::kad::Mode::Server));
 
         // initiate bootstrap
-        for addr in bootstraps {
+        for addr in &nodes.bootstrap_nodes {
             log::info!("Dialling bootstrap: {:#?}", addr);
             if let Some(peer_id) = addr.iter().find_map(|p| match p {
                 Protocol::P2p(peer_id) => Some(peer_id),
@@ -97,7 +95,10 @@
             }) {
                 swarm.dial(addr.clone())?;
                 log::info!("Adding {} to Kademlia routing table", addr);
-                swarm.behaviour_mut().kademlia.add_address(&peer_id, addr);
+                swarm
+                    .behaviour_mut()
+                    .kademlia
+                    .add_address(&peer_id, addr.clone());
             } else {
                 log::warn!("Missing peerID in address: {}", addr);
             }
@@ -115,17 +116,29 @@
         // listen on all interfaces for incoming connections
         log::info!("Listening p2p network on: {}", listen_addr);
         swarm.listen_on(listen_addr)?;
-        for addr in relays {
+
+        // listen on relay addresses with p2p circuit
+        for addr in &nodes.relay_nodes {
             log::info!("Listening to relay: {}", addr);
             swarm.listen_on(addr.clone().with(Protocol::P2pCircuit))?;
         }
 
         // dial rpc nodes
-        for rpc_addr in rpcs {
+        for rpc_addr in &nodes.rpc_nodes {
             log::info!("Dialing RPC node: {}", rpc_addr);
-            swarm.dial(rpc_addr)?;
+            swarm.dial(rpc_addr.clone())?;
         }
 
+        // add rpcs as explicit peers
+        // TODO: may not be necessary
+        // for rpc_peer_id in &nodes.rpc_peerids {
+        //     log::info!("Adding {} as explicit peer.", rpc_peer_id);
+        //     swarm
+        //         .behaviour_mut()
+        //         .gossipsub
+        //         .add_explicit_peer(rpc_peer_id);
+        // }
+
         // create commander
         let (cmd_tx, cmd_rx) = mpsc::channel(COMMAND_CHANNEL_BUFSIZE);
         let commander = DriaP2PCommander::new(cmd_tx, protocol.clone());
diff --git a/p2p/src/nodes.rs b/p2p/src/nodes.rs
index efdde4b..1b062cf 100644
--- a/p2p/src/nodes.rs
+++ b/p2p/src/nodes.rs
@@ -29,6 +29,26 @@ impl DriaNodes {
         }
     }
 
+    pub fn with_relay_nodes(mut self, addresses: impl IntoIterator<Item = Multiaddr>) -> Self {
+        self.relay_nodes.extend(addresses);
+        self
+    }
+
+    pub fn with_bootstrap_nodes(mut self, addresses: impl IntoIterator<Item = Multiaddr>) -> Self {
+        self.bootstrap_nodes.extend(addresses);
+        self
+    }
+
+    pub fn with_rpc_nodes(mut self, addresses: impl IntoIterator<Item = Multiaddr>) -> Self {
+        self.rpc_nodes.extend(addresses);
+        self
+    }
+
+    pub fn with_rpc_peer_ids(mut self, addresses: impl IntoIterator<Item = PeerId>) -> Self {
+        self.rpc_peerids.extend(addresses);
+        self
+    }
+
     /// Parses static bootstrap & relay nodes from environment variables.
     ///
     /// The environment variables are:
diff --git a/p2p/tests/listen_test.rs b/p2p/tests/listen_test.rs
index 4b52619..bfeafd6 100644
--- a/p2p/tests/listen_test.rs
+++ b/p2p/tests/listen_test.rs
@@ -1,4 +1,4 @@
-use dkn_p2p::{DriaP2PClient, DriaP2PProtocol};
+use dkn_p2p::{DriaNodes, DriaP2PClient, DriaP2PProtocol};
 use eyre::Result;
 use libp2p_identity::Keypair;
 
@@ -12,13 +12,18 @@ async fn test_listen_topic_once() -> Result<()> {
         .is_test(true)
         .try_init();
 
+    let listen_addr = "/ip4/0.0.0.0/tcp/4001".parse()?;
+
+    // prepare nodes
+    let nodes = DriaNodes::new(dkn_p2p::DriaNetworkType::Community)
+        .with_bootstrap_nodes(["/ip4/44.206.245.139/tcp/4001/p2p/16Uiu2HAm4q3LZU2T9kgjKK4ysy6KZYKLq8KiXQyae4RHdF7uqSt4".parse()?])
+        .with_relay_nodes(["/ip4/34.201.33.141/tcp/4001/p2p/16Uiu2HAkuXiV2CQkC9eJgU6cMnJ9SMARa85FZ6miTkvn5fuHNufa".parse()?]);
+
     // spawn P2P client in another task
     let (client, mut commander, mut msg_rx) = DriaP2PClient::new(
         Keypair::generate_secp256k1(),
-        "/ip4/0.0.0.0/tcp/4001".parse()?,
-        vec!["/ip4/44.206.245.139/tcp/4001/p2p/16Uiu2HAm4q3LZU2T9kgjKK4ysy6KZYKLq8KiXQyae4RHdF7uqSt4".parse()?].into_iter(),
-        vec!["/ip4/34.201.33.141/tcp/4001/p2p/16Uiu2HAkuXiV2CQkC9eJgU6cMnJ9SMARa85FZ6miTkvn5fuHNufa".parse()?].into_iter(),
-        vec![].into_iter(),
+        listen_addr,
+        &nodes,
         DriaP2PProtocol::default(),
     )
     .expect("could not create p2p client");