From f5d2bdef75b8eae2d891b588b461f663b64c2e64 Mon Sep 17 00:00:00 2001 From: Flix Date: Thu, 28 Dec 2023 19:55:14 +0100 Subject: [PATCH] feat: Implement concurrency limits Also rename runner builder methods. --- examples/context.rs | 2 +- examples/error_handling.rs | 2 +- src/job.rs | 13 ++++++- src/runner.rs | 79 ++++++++++++++++++++++++++++++-------- 4 files changed, 77 insertions(+), 19 deletions(-) diff --git a/examples/context.rs b/examples/context.rs index d5eff18..43ca083 100644 --- a/examples/context.rs +++ b/examples/context.rs @@ -35,7 +35,7 @@ async fn main() -> Result<()> { // Start the job runner to execute jobs from the messages in the queue in the // database. - let job_runner = JobRunner::new(db.clone()).set_context("cats").run::(); + let job_runner = JobRunner::new(db.clone()).with_context("cats").run::(); // Spawn new jobs via a message on the database queue. let job_id = JobRegistry::Greet.builder().spawn(&db).await?; diff --git a/examples/error_handling.rs b/examples/error_handling.rs index 8dd74a9..bedf34c 100644 --- a/examples/error_handling.rs +++ b/examples/error_handling.rs @@ -40,7 +40,7 @@ async fn main() -> Result<()> { let error_received = Arc::new(AtomicBool::new(false)); let err_received = error_received.clone(); let job_runner = JobRunner::new(db.clone()) - .set_error_handler(move |_err| { + .with_error_handler(move |_err| { err_received.store(true, Ordering::SeqCst); }) .run::(); diff --git a/src/job.rs b/src/job.rs index a32ef0b..69ff30e 100644 --- a/src/job.rs +++ b/src/job.rs @@ -1,6 +1,9 @@ //! Provider for job handlers. -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; use serde::{de::DeserializeOwned, Serialize}; use tokio::task::JoinHandle; @@ -113,10 +116,15 @@ impl CurrentJob { } /// Job running function that handles retries as well etc. - pub(crate) fn run(mut self, mut function: JobFunctionType) -> JoinHandle> { + pub(crate) fn run( + mut self, + mut function: JobFunctionType, + currently_running: Arc, + ) -> JoinHandle> { self.keep_alive = Some(Self::keep_alive(self.db.clone(), self.id).into()); let span = tracing::debug_span!("job-run"); + currently_running.fetch_add(1, Ordering::Relaxed); tokio::task::spawn( async move { let id = self.id; @@ -124,6 +132,7 @@ impl CurrentJob { tracing::trace!("Starting job with ID {id}."); let res = function(self).await; + currently_running.fetch_sub(1, Ordering::Relaxed); // Handle the job's error if let Err(err) = res { diff --git a/src/runner.rs b/src/runner.rs index 195f448..2d0628b 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -1,7 +1,16 @@ //! Connector to the database which runs code based on the messages and their //! type. -use std::{fmt::Debug, sync::Arc, time::Duration}; +use std::{ + fmt::Debug, + ops::Range, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + thread::available_parallelism, + time::Duration, +}; use bonsaidb::core::{ async_trait::async_trait, @@ -34,6 +43,9 @@ pub struct JobRunner { error_handler: Option, /// Outside context type-map to provide resources to the jobs. context: Context, + /// Concurrency limits, a range from minimum to maximum concurrent jobs to + /// be targeted in the execution queue. + concurrency: Range, } impl JobRunner @@ -42,12 +54,15 @@ where { /// Create a new job runner on this database. pub fn new(db: DB) -> Self { - Self { db, error_handler: None, context: Context::new() } + let concurrency = available_parallelism() + .map(|num_cpus| usize::from(num_cpus) as u32 / 2..usize::from(num_cpus) as u32 * 2) + .unwrap_or(3_u32..8_u32); + Self { db, error_handler: None, context: Context::new(), concurrency } } /// Set the error handler callback to be called when jobs return an error. #[must_use] - pub fn set_error_handler(mut self, handler: F) -> Self + pub fn with_error_handler(mut self, handler: F) -> Self where F: Fn(Box) + Send + Sync + 'static, { @@ -57,11 +72,18 @@ where /// Add context to the runner. Only one instance per type can be inserted! #[must_use] - pub fn set_context(mut self, context: C) -> Self { + pub fn with_context(mut self, context: C) -> Self { self.context.insert(context); self } + /// Set the concurrency limits. + #[must_use] + pub fn with_concurrency_limits(mut self, min_concurrent: u32, max_concurrent: u32) -> Self { + self.concurrency = min_concurrent..max_concurrent; + self + } + /// Spawn and run the daemon for processing messages/jobs in the background. /// Keep this handle as long as you want jobs to be executed in the /// background! You can also use and await the handle like normal @@ -75,6 +97,7 @@ where db: Arc::new(self.db), error_handler: self.error_handler, context: Arc::new(self.context), + concurrency: self.concurrency, }; tokio::task::spawn(internal_runner.job_queue::()).into() } @@ -86,6 +109,7 @@ impl Debug for JobRunner { .field("db", &self.db) .field("error_handler", &"") .field("context", &self.context) + .field("concurrency", &self.concurrency) .finish() } } @@ -98,6 +122,9 @@ struct InternalJobRunner { error_handler: Option, /// Outside context type-map to provide resources to the jobs. context: Arc, + /// Concurrency limits, a range from minimum to maximum concurrent jobs to + /// be targeted in the execution queue. + concurrency: Range, } impl Clone for InternalJobRunner { @@ -106,6 +133,7 @@ impl Clone for InternalJobRunner { db: self.db.clone(), error_handler: self.error_handler.clone(), context: self.context.clone(), + concurrency: self.concurrency.clone(), } } } @@ -118,8 +146,14 @@ where async fn due_messages( &self, due_at: Timestamp, + limit: u32, ) -> Result, DueMessages>, BonsaiError> { - self.db.view::().with_key_range(..due_at).query_with_collection_docs().await + self.db + .view::() + .with_key_range(..due_at) + .limit(limit) + .query_with_collection_docs() + .await } /// Get the duration until the next message is due. @@ -158,14 +192,25 @@ where let subscriber = self.db.create_subscriber().await?; subscriber.subscribe_to(&MQ_NOTIFY).await?; + let currently_running = Arc::new(AtomicUsize::new(0)); loop { - // Retrieve due messages let now = OffsetDateTime::now_utc().unix_timestamp_nanos(); - let messages = self.due_messages(now).await?; - tracing::trace!("Found {} due messages.", messages.len()); + + // Retrieve due messages if there is not enough running already + let running = currently_running.load(Ordering::Relaxed) as u32; + #[allow(clippy::if_then_some_else_none)] // It is async. + let messages = if running < self.concurrency.start { + Some(self.due_messages(now, self.concurrency.end.saturating_sub(running)).await?) + } else { + None + }; + tracing::trace!( + "Handling {} due messages.", + messages.as_ref().map_or(0, MappedDocuments::len) + ); // Execute jobs for the messages - for msg in &messages { + for msg in messages.iter().flatten() { if let Some(job) = REG::from_name(&msg.document.contents.name) { // Filter out messages with active dependencies if let Some(dependency) = msg.document.contents.execute_after { @@ -187,8 +232,8 @@ where keep_alive: None, }; - // Dropping the handle to the running job.. Panics will not cause - let _jh = current_job.run(job.function()); + // Dropping the handle to the running job.. Panics will not cause anything. + let _jh = current_job.run(job.function(), currently_running.clone()); } } else { tracing::trace!( @@ -200,10 +245,13 @@ where // Sleep until the next message is due or a notification comes in. let next_due_in = self.next_message_due_in(now).await?; - tokio::time::timeout(next_due_in, subscriber.receiver().receive_async()) - .await - .ok() // Timeout is not a failure - .transpose()?; + tokio::time::timeout( + next_due_in.max(Duration::from_millis(100)), // Wait at least 100 ms. + subscriber.receiver().receive_async(), + ) + .await + .ok() // Timeout is not a failure + .transpose()?; } } } @@ -214,6 +262,7 @@ impl Debug for InternalJobRunner { .field("db", &self.db) .field("error_handler", &"") .field("context", &self.context) + .field("concurrency", &self.concurrency) .finish() } }