diff --git a/src/reduce.rs b/src/reduce.rs index 36018ce..fc7bf56 100644 --- a/src/reduce.rs +++ b/src/reduce.rs @@ -8,8 +8,9 @@ use tokio::sync::mpsc; use tokio::sync::mpsc::Sender; use tokio::task::JoinSet; use tokio_stream::wrappers::ReceiverStream; -use tonic::metadata::MetadataMap; use tonic::{async_trait, Request, Response, Status}; +use tonic::metadata::MetadataMap; + use crate::shared; const KEY_JOIN_DELIMITER: &str = ":"; @@ -25,8 +26,52 @@ pub mod proto { tonic::include_proto!("reduce.v1"); } -struct ReduceService { - handler: Arc, +struct ReduceService { + creator: C, +} + +/// `ReducerCreator` is a trait for creating a new instance of a `Reducer`. +pub trait ReducerCreator { + /// Each type that implements `ReducerCreator` must also specify an associated type `R` that implements the `Reducer` trait. + /// The `create` method is used to create a new instance of this `Reducer` type. + /// + /// # Example + /// + /// Below is an example of how to implement the `ReducerCreator` trait for a specific type `MyReducerCreator`. + /// `MyReducerCreator` creates instances of `MyReducer`, which is a type that implements the `Reducer` trait. + /// + /// ```rust + /// use numaflow::reduce::{Reducer, ReducerCreator, ReduceRequest, Metadata, Message}; + /// use tokio::sync::mpsc::Receiver; + /// use tonic::async_trait; + /// + /// pub struct MyReducer; + /// + /// #[async_trait] + /// impl Reducer for MyReducer { + /// async fn reduce( + /// &self, + /// keys: Vec, + /// mut input: Receiver, + /// md: &Metadata, + /// ) -> Vec { + /// // Implementation of the reduce method goes here. + /// vec![] + /// } + /// } + /// + /// pub struct MyReducerCreator; + /// + /// impl ReducerCreator for MyReducerCreator { + /// type R = MyReducer; + /// + /// fn create(&self) -> Self::R { + /// MyReducer + /// } + /// } + /// ``` + type R: Reducer + Send + Sync + 'static; + fn create(&self) -> Self::R; } /// Reducer trait for implementing Reduce handler. @@ -46,8 +91,8 @@ pub trait Reducer { /// /// #[tokio::main] /// async fn main() -> Result<(), Box> { - /// let reduce_handler = counter::Counter::new(); - /// reduce::Server::new(reduce_handler).start().await?; + /// let handler_creator = counter::CounterCreator{}; + /// reduce::Server::new(handler_creator).start().await?; /// Ok(()) /// } /// mod counter { @@ -57,6 +102,17 @@ pub trait Reducer { /// use tonic::async_trait; /// use numaflow::reduce::proto::reduce_server::Reduce; /// pub(crate) struct Counter {} + /// + /// pub(crate) struct CounterCreator {} + /// + /// impl numaflow::reduce::ReducerCreator for CounterCreator { + /// type R = Counter; + /// + /// fn create(&self) -> Counter { + /// Counter::new() + /// } + /// } + /// /// impl Counter { /// pub(crate) fn new() -> Self { /// Self {} @@ -68,7 +124,7 @@ pub trait Reducer { /// &self, /// keys: Vec, /// mut input: Receiver, - /// md: Metadata, + /// md: &Metadata, /// ) -> Vec { /// let mut counter = 0; /// // the loop exits when input is closed which will happen only on close of book. @@ -89,12 +145,11 @@ pub trait Reducer { &self, keys: Vec, input: mpsc::Receiver, - md: Metadata, + md: &Metadata, ) -> Vec; } /// IntervalWindow is the start and end boundary of the window. -#[derive(Clone)] pub struct IntervalWindow { // start time of the window pub start_time: DateTime, @@ -114,10 +169,9 @@ impl Metadata { } } -#[derive(Clone)] /// Metadata are additional information passed into the [`Reducer::reduce`]. pub struct Metadata { - pub interval_window: IntervalWindow + pub interval_window: IntervalWindow, } /// Message is the response from the user's [`Reducer::reduce`]. @@ -182,9 +236,9 @@ fn get_window_details(request: &MetadataMap) -> (DateTime, DateTime) { } #[async_trait] -impl proto::reduce_server::Reduce for ReduceService -where - T: Reducer + Send + Sync + 'static, +impl proto::reduce_server::Reduce for ReduceService + where + C: ReducerCreator + Send + Sync + 'static, { type ReduceFnStream = ReceiverStream>; async fn reduce_fn( @@ -193,7 +247,7 @@ where ) -> Result, Status> { // get gRPC window from metadata let (start_win, end_win) = get_window_details(request.metadata()); - let md = Metadata::new(IntervalWindow::new(start_win, end_win)); + let md = Arc::new(Metadata::new(IntervalWindow::new(start_win, end_win))); let mut key_to_tx: HashMap> = HashMap::new(); @@ -214,12 +268,12 @@ where // since we are calling this in a loop, we need make sure that there is reference counting // and the lifetime of self is more than the async function. // try Arc https://doc.rust-lang.org/reference/items/associated-items.html#methods ? - let v = Arc::clone(&self.handler); + let handler = self.creator.create(); + let m = Arc::clone(&md); // spawn task for each unique key let keys = rr.keys.clone(); - let reduce_md = md.clone(); - set.spawn(async move { v.reduce(keys, rx, reduce_md).await }); + set.spawn(async move { handler.reduce(keys, rx, m.as_ref()).await }); // write data into the channel tx.send(rr.into()).await.unwrap(); @@ -251,8 +305,8 @@ where tx.send(Ok(proto::ReduceResponse { results: datum_responses, })) - .await - .unwrap(); + .await + .unwrap(); } }); @@ -267,21 +321,21 @@ where /// gRPC server to start a reduce service #[derive(Debug)] -pub struct Server { +pub struct Server { sock_addr: PathBuf, max_message_size: usize, server_info_file: PathBuf, - svc: Option, + creator: Option, } -impl Server { +impl Server { /// Create a new Server with the given reduce service - pub fn new(reduce_svc: T) -> Self { + pub fn new(creator: C) -> Self { Server { sock_addr: DEFAULT_SOCK_ADDR.into(), max_message_size: DEFAULT_MAX_MESSAGE_SIZE, server_info_file: DEFAULT_SERVER_INFO_FILE.into(), - svc: Some(reduce_svc), + creator: Some(creator), } } @@ -325,12 +379,12 @@ impl Server { shutdown: F, ) -> Result<(), Box> where - T: Reducer + Send + Sync + 'static, - F: Future, + F: Future, + C: ReducerCreator + Send + Sync + 'static, { let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?; - let handler = Arc::new(self.svc.take().unwrap()); - let reduce_svc = ReduceService { handler }; + let creator = self.creator.take().unwrap(); + let reduce_svc = ReduceService { creator }; let reduce_svc = proto::reduce_server::ReduceServer::new(reduce_svc) .max_encoding_message_size(self.max_message_size) .max_decoding_message_size(self.max_message_size); @@ -343,9 +397,9 @@ impl Server { } /// Starts the gRPC server. Automatically registers signal handlers for SIGINT and SIGTERM and initiates graceful shutdown of gRPC server when either one of the signal arrives. - pub async fn start(&mut self) -> Result<(), Box> + pub async fn start(&mut self) -> Result<(), Box> where - T: Reducer + Send + Sync + 'static, + C: ReducerCreator + Send + Sync + 'static, { self.start_with_shutdown(shared::shutdown_signal()).await }