diff --git a/.rustfmt.toml b/.rustfmt.toml index 3a2051d..51ed464 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1,2 +1,2 @@ edition = "2021" -indent_style = "Block" +indent_style = "Block" \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 9f00866..16a7d63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,10 +8,11 @@ name = "numaflow" path = "src/lib.rs" [dependencies] -tonic = "0.10.2" +tonic = "0.11.0" prost = "0.12.3" prost-types = "0.12.3" tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "signal"] } +tokio-util = "0.7.10" tokio-stream = { version = "0.1.14", features = ["net"] } serde = { version = "1.0.194", features = ["derive"] } chrono = "0.4.31" @@ -19,9 +20,10 @@ serde_json = "1.0.111" futures-util = "0.3.30" tracing = "0.1.40" uuid = { version = "1.8.0", features = ["v4"] } +thiserror = "1.0" [build-dependencies] -tonic-build = "0.10.2" +tonic-build = "0.11.0" [dev-dependencies] tempfile = "3.9.0" diff --git a/examples/map-cat/src/main.rs b/examples/map-cat/src/main.rs index 4eefaf4..ab9c5e2 100644 --- a/examples/map-cat/src/main.rs +++ b/examples/map-cat/src/main.rs @@ -1,5 +1,4 @@ use numaflow::map; -use std::error::Error; #[tokio::main] async fn main() -> Result<(), Box> { @@ -11,9 +10,7 @@ struct Cat; #[tonic::async_trait] impl map::Mapper for Cat { async fn map(&self, input: map::MapRequest) -> Vec { - let message = map::Message::new(input.value) - .keys(input.keys) - .tags(vec![]); + let message = map::Message::new(input.value).keys(input.keys).tags(vec![]); vec![message] } } diff --git a/examples/map-tickgen-serde/src/main.rs b/examples/map-tickgen-serde/src/main.rs index 645d215..fecb8cf 100644 --- a/examples/map-tickgen-serde/src/main.rs +++ b/examples/map-tickgen-serde/src/main.rs @@ -1,14 +1,13 @@ +use chrono::{SecondsFormat, TimeZone, Utc}; use numaflow::map; use numaflow::map::Message; +use serde::Serialize; #[tokio::main] async fn main() -> Result<(), Box> { map::Server::new(TickGen).start().await } -use chrono::{SecondsFormat, TimeZone, Utc}; -use serde::Serialize; - struct TickGen; #[derive(serde::Deserialize)] diff --git a/examples/reduce-counter/Cargo.toml b/examples/reduce-counter/Cargo.toml index 1fba368..0569082 100644 --- a/examples/reduce-counter/Cargo.toml +++ b/examples/reduce-counter/Cargo.toml @@ -8,6 +8,6 @@ name = "server" path = "src/main.rs" [dependencies] -tonic = "0.9" +tonic = "0.11.0" tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } -numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", branch="reduce" } +numaflow-rs = { path = "../../" } diff --git a/examples/reduce-counter/src/main.rs b/examples/reduce-counter/src/main.rs index 7e79882..77fa50b 100644 --- a/examples/reduce-counter/src/main.rs +++ b/examples/reduce-counter/src/main.rs @@ -1,22 +1,30 @@ -use numaflow::reduce::start_uds_server; +use numaflow::reduce; #[tokio::main] -async fn main() -> Result<(), Box> { - let reduce_handler = counter::Counter::new(); - - start_uds_server(reduce_handler).await?; - +async fn main() -> Result<(), Box> { + let handler_creator = counter::CounterCreator {}; + reduce::Server::new(handler_creator).start().await?; Ok(()) } mod counter { - use numaflow::reduce::{Datum, Message}; - use numaflow::reduce::{Metadata, Reducer}; + use numaflow::reduce::{Message, ReduceRequest}; + use numaflow::reduce::{Reducer, Metadata}; use tokio::sync::mpsc::Receiver; use tonic::async_trait; pub(crate) struct Counter {} + pub(crate) struct CounterCreator {} + + impl numaflow::reduce::ReducerCreator for 
CounterCreator { + type R = Counter; + + fn create(&self) -> Self::R { + Counter::new() + } + } + impl Counter { pub(crate) fn new() -> Self { Self {} @@ -25,33 +33,19 @@ mod counter { #[async_trait] impl Reducer for Counter { - async fn reduce( + async fn reduce( &self, keys: Vec, - mut input: Receiver, - md: &U, + mut input: Receiver, + md: &Metadata, ) -> Vec { - println!( - "Entering into UDF {:?} {:?}", - md.start_time(), - md.end_time() - ); - let mut counter = 0; // the loop exits when input is closed which will happen only on close of book. - while (input.recv().await).is_some() { + while input.recv().await.is_some() { counter += 1; } - - println!( - "Returning from UDF {:?} {:?}", - md.start_time(), - md.end_time() - ); - let message = reduce::Message::new(counter.to_string().into_bytes()) - .keys(keys.clone()) - .tags(vec![]); + let message = Message::new(counter.to_string().into_bytes()).tags(vec![]).keys(keys.clone()); vec![message] } } -} +} \ No newline at end of file diff --git a/examples/side-input/manifests/simple-sideinput.yaml b/examples/side-input/manifests/simple-sideinput.yaml index a893140..6fb7008 100644 --- a/examples/side-input/manifests/simple-sideinput.yaml +++ b/examples/side-input/manifests/simple-sideinput.yaml @@ -32,7 +32,7 @@ spec: - name: out sink: # A simple log printing sink - log: {} + log: { } edges: - from: in to: si-log diff --git a/examples/side-input/src/main.rs b/examples/side-input/src/main.rs index 2d16c6e..7f842b7 100644 --- a/examples/side-input/src/main.rs +++ b/examples/side-input/src/main.rs @@ -1,10 +1,8 @@ -use std::time::{SystemTime, UNIX_EPOCH}; -use numaflow::sideinput::start_uds_server; -use numaflow::sideinput::SideInputer; -use tonic::{async_trait}; use std::sync::Mutex; - +use numaflow::sideinput::SideInputer; +use numaflow::sideinput::start_uds_server; +use tonic::async_trait; struct SideInputHandler { counter: Mutex, @@ -20,8 +18,7 @@ impl SideInputHandler { #[async_trait] impl SideInputer for SideInputHandler { - - async fn retrieve_sideinput(& self) -> Option> { + async fn retrieve_sideinput(&self) -> Option> { let current_time = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("Time went backwards"); @@ -42,4 +39,4 @@ async fn main() -> Result<(), Box> { let side_input_handler = SideInputHandler::new(); start_uds_server(side_input_handler).await?; Ok(()) -} \ No newline at end of file +} diff --git a/examples/sideinput-udf/src/main.rs b/examples/sideinput-udf/src/main.rs index 91d7969..b823388 100644 --- a/examples/sideinput-udf/src/main.rs +++ b/examples/sideinput-udf/src/main.rs @@ -1,6 +1,7 @@ -use notify::{RecursiveMode, Result, Watcher}; -use numaflow::map::{MapRequest, Mapper, Message, Server}; use std::path::Path; + +use notify::{RecursiveMode, Result, Watcher}; +use numaflow::map::{Mapper, MapRequest, Message, Server}; use tokio::spawn; use tonic::async_trait; diff --git a/examples/simple-source/manifests/simple-source.yaml b/examples/simple-source/manifests/simple-source.yaml index 33923d4..c2006e7 100644 --- a/examples/simple-source/manifests/simple-source.yaml +++ b/examples/simple-source/manifests/simple-source.yaml @@ -16,7 +16,7 @@ spec: scale: min: 1 sink: - log: {} + log: { } edges: - from: in to: out \ No newline at end of file diff --git a/examples/simple-source/src/main.rs b/examples/simple-source/src/main.rs index 37dc873..b05cd10 100644 --- a/examples/simple-source/src/main.rs +++ b/examples/simple-source/src/main.rs @@ -10,16 +10,16 @@ pub(crate) mod simple_source { use numaflow::source::{Message, 
Offset, SourceReadRequest, Sourcer}; use std::collections::HashMap; use std::sync::Arc; + use std::sync::Arc; use std::{ - collections::HashSet, collections::HashMap, + collections::HashSet, sync::atomic::{AtomicUsize, Ordering}, sync::RwLock, }; use tokio::{sync::mpsc::Sender, time::Instant}; use tonic::async_trait; use uuid::Uuid; - use std::sync::Arc; /// SimpleSource is a data generator which generates monotonically increasing offsets and data. It is a shared state which is protected using Locks /// or Atomics to provide concurrent access. Numaflow actually does not require concurrent access but we are forced to do this because the SDK @@ -55,7 +55,6 @@ pub(crate) mod simple_source { let mut headers = HashMap::new(); headers.insert(String::from("x-txn-id"), String::from(Uuid::new_v4())); - // increment the read_idx which is used as the offset self.read_idx .store(self.read_idx.load(Ordering::Relaxed) + 1, Ordering::Relaxed); diff --git a/examples/sink-log/src/main.rs b/examples/sink-log/src/main.rs index 302fa95..53dbcba 100644 --- a/examples/sink-log/src/main.rs +++ b/examples/sink-log/src/main.rs @@ -1,5 +1,4 @@ use numaflow::sink::{self, Response, SinkRequest}; -use std::error::Error; #[tokio::main] async fn main() -> Result<(), Box> { diff --git a/examples/source-transformer-now/src/main.rs b/examples/source-transformer-now/src/main.rs index 5c4b7ea..b42454f 100644 --- a/examples/source-transformer-now/src/main.rs +++ b/examples/source-transformer-now/src/main.rs @@ -1,5 +1,4 @@ use numaflow::sourcetransform; -use std::error::Error; /// A simple source transformer which assigns event time to the current time in utc. diff --git a/proto/reduce.proto b/proto/reduce.proto index a789f97..1e21390 100644 --- a/proto/reduce.proto +++ b/proto/reduce.proto @@ -18,23 +18,58 @@ service Reduce { * ReduceRequest represents a request element. */ message ReduceRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - map headers = 5; + // WindowOperation represents a window operation. + // For Aligned windows, OPEN, APPEND and CLOSE events are sent. + message WindowOperation { + enum Event { + OPEN = 0; + CLOSE = 1; + APPEND = 4; + } + + Event event = 1; + repeated Window windows = 2; + } + + // Payload represents a payload element. + message Payload { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + map headers = 5; + } + + Payload payload = 1; + WindowOperation operation = 2; +} + +// Window represents a window. +// Since the client doesn't track keys, window doesn't have a keys field. +message Window { + google.protobuf.Timestamp start = 1; + google.protobuf.Timestamp end = 2; + string slot = 3; } /** * ReduceResponse represents a response element. */ message ReduceResponse { + // Result represents a result element. It contains the result of the reduce function. message Result { repeated string keys = 1; bytes value = 2; repeated string tags = 3; } - repeated Result results = 1; + + Result result = 1; + + // window represents a window to which the result belongs. + Window window = 2; + + // EOF represents the end of the response for a window. 
+ bool EOF = 3; } /** diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..c4bdb52 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,25 @@ +use thiserror::Error; + +#[derive(Error, Debug, Clone)] +pub enum ErrorKind { + #[error("User Defined Error: {0}")] + UserDefinedError(String), + + #[error("Internal Error: {0}")] + InternalError(String), +} + +#[derive(Error, Debug, Clone)] +pub enum Error { + #[error("Map Error - {0}")] + MapError(ErrorKind), + + #[error("Reduce Error - {0}")] + ReduceError(ErrorKind), + + #[error("Sink Error - {0}")] + SinkError(ErrorKind), + + #[error("Source Error - {0}")] + SourceError(ErrorKind), +} diff --git a/src/lib.rs b/src/lib.rs index 79960bd..1313682 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ //! [Numaflow]: https://numaflow.numaproj.io/ //! [Map]: https://numaflow.numaproj.io/user-guide/user-defined-functions/map/map/ //! [Reduce]: https://numaflow.numaproj.io/user-guide/user-defined-functions/reduce/reduce/ +//! [User Defined Sources]: https://numaflow.numaproj.io/user-guide/sources/user-defined-sources/ //! [User Defined Sinks]: https://numaflow.numaproj.io/user-guide/sinks/user-defined-sinks/ /// start up code @@ -30,3 +31,44 @@ pub mod sink; /// building [side input](https://numaflow.numaproj.io/user-guide/reference/side-inputs/) pub mod sideinput; + +// Error handling in Numaflow SDKs +// +// Any non-recoverable error will cause the process to shut down with a non-zero exit status, and all errors are non-recoverable: +// if an error is retriable, we (gRPC or the Numaflow SDK) will already have retried it (hence it is not surfaced as an error), which means +// every error raised by the SDK is non-recoverable. +// +// Task ordering and error propagation. +// +// level-1 level-2 level-3 +// +// +---> (service_fn) -> +// | +// | +// | +---> (task) +// | | +// | | +// (gRPC Service) ---+---> (service_fn) ---+---> (task) +// ^ | | +// | | | +// | | +---> (task) +// | | +// (shutdown) | +// | +---> (service_fn) -> +// | +// | +// (user) +// +// If a task at level-3 hits an error, that error is propagated to level-2 (service_fn) via an mpsc::channel (the response channel). +// The response channel carries a Result, so by returning Err() on it the task notifies the service_fn above that it wants to abort itself. +// service_fn (level-2) then uses another mpsc::channel to tell the gRPC service to cancel all the service_fns. The gRPC service +// asks all the level-2 service_fns to abort using the CancellationToken, and each service_fn calls abort on the tasks it created via its internal +// mpsc::channel once the CancellationToken has been dropped/cancelled. +// +// The user can also send a shutdown request directly to the gRPC server, which in turn cancels the CancellationToken. +// +// The above 3-level task ordering is only needed for complex cases like reduce; simpler endpoints like `map` have only 2 levels, but +// the error propagation is handled the same way.
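For illustration only (not part of this patch): a minimal, self-contained sketch of the propagation pattern described above, using tokio and tokio-util directly. A worker reports its Result on an mpsc response channel; on the first Err the supervisor (the gRPC service in the SDK) cancels a CancellationToken, and the remaining workers abort through their cancelled() branch. The `worker` function and its failure condition are hypothetical.

```rust
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

// Hypothetical stand-in for a level-3 task running the user-defined function.
async fn worker(
    id: u32,
    response_tx: mpsc::Sender<Result<u32, String>>,
    token: CancellationToken,
) {
    // Simulated user-defined work: task 2 fails, the others succeed.
    let result = if id == 2 { Err(format!("task {id} failed")) } else { Ok(id) };

    tokio::select! {
        // Report the outcome on the response channel; Err means "abort / shut down".
        _ = response_tx.send(result) => {}
        // If the supervisor has already cancelled the token, exit without reporting.
        _ = token.cancelled() => {}
    }
}

#[tokio::main]
async fn main() {
    let (response_tx, mut response_rx) = mpsc::channel(8);
    let token = CancellationToken::new();

    for id in 1..=3 {
        tokio::spawn(worker(id, response_tx.clone(), token.clone()));
    }
    drop(response_tx); // recv() yields None once every worker has finished

    // The supervisor turns the first error into a shutdown: cancelling the token
    // makes the remaining workers exit through their cancelled() branch.
    while let Some(result) = response_rx.recv().await {
        if let Err(e) = result {
            eprintln!("shutting down after error: {e}");
            token.cancel();
        }
    }
}
```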
+ +/// error module +pub mod error; diff --git a/src/map.rs b/src/map.rs index af43009..eb9fa42 100644 --- a/src/map.rs +++ b/src/map.rs @@ -1,17 +1,20 @@ use std::collections::HashMap; -use std::future::Future; +use std::fs; use std::path::PathBuf; use chrono::{DateTime, Utc}; -use serde_json::Value; +use tokio::sync::{mpsc, oneshot}; +use tokio_util::sync::CancellationToken; use tonic::{async_trait, Request, Response, Status}; use crate::shared; +use crate::shared::shutdown_signal; const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/map.sock"; const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/mapper-server-info"; const DROP: &str = "U+005C__DROP__"; + /// Numaflow Map Proto definitions. pub mod proto { tonic::include_proto!("map.v1"); @@ -19,6 +22,9 @@ pub mod proto { struct MapService { handler: T, + // not used ATM + // PLEASE WRITE WHY + _shutdown_tx: mpsc::Sender<()>, } /// Mapper trait for implementing Map handler. @@ -88,6 +94,7 @@ pub struct Message { /// Tags are used for [conditional forwarding](https://numaflow.numaproj.io/user-guide/reference/conditional-forwarding/). pub tags: Option>, } + /// Represents a message that can be modified and forwarded. impl Message { /// Creates a new message with the specified value. @@ -269,26 +276,39 @@ impl Server { } /// Starts the gRPC server. When message is received on the `shutdown` channel, graceful shutdown of the gRPC server will be initiated. - pub async fn start_with_shutdown( + pub async fn start_with_shutdown( &mut self, - shutdown: F, + shutdown_rx: oneshot::Receiver<()>, ) -> Result<(), Box> where T: Mapper + Send + Sync + 'static, - F: Future, { let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?; let handler = self.svc.take().unwrap(); - let map_svc = MapService { handler }; + + // Create a channel to send shutdown signal to the server to do graceful shutdown in case of non retryable errors. + let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); + let map_svc = MapService { + handler, + _shutdown_tx: internal_shutdown_tx, + }; + let map_svc = proto::map_server::MapServer::new(map_svc) .max_encoding_message_size(self.max_message_size) .max_decoding_message_size(self.max_message_size); + let shutdown = shutdown_signal( + internal_shutdown_rx, + Some(shutdown_rx), + CancellationToken::new(), + ); + tonic::transport::Server::builder() .add_service(map_svc) .serve_with_incoming_shutdown(listener, shutdown) - .await - .map_err(Into::into) + .await?; + + Ok(()) } /// Starts the gRPC server. Automatically registers signal handlers for SIGINT and SIGTERM and initiates graceful shutdown of gRPC server when either one of the signal arrives. @@ -296,7 +316,16 @@ impl Server { where T: Mapper + Send + Sync + 'static, { - self.start_with_shutdown(shared::shutdown_signal()).await + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + self.start_with_shutdown(shutdown_rx).await + } +} + +impl Drop for Server { + // Cleanup the socket file when the server is dropped so that when the server is restarted, it can bind to the + // same address. UnixListener doesn't implement Drop trait, so we have to manually remove the socket file. 
+ fn drop(&mut self) { + let _ = fs::remove_file(&self.sock_addr); } } @@ -304,8 +333,8 @@ impl Server { mod tests { use crate::map; use crate::map::proto::map_client::MapClient; - use crate::map::Message; use std::{error::Error, time::Duration}; + use tempfile::TempDir; use tokio::sync::oneshot; use tonic::transport::Uri; @@ -339,10 +368,7 @@ mod tests { assert_eq!(server.socket_file(), sock_file); let (shutdown_tx, shutdown_rx) = oneshot::channel(); - let shutdown = async { - shutdown_rx.await.unwrap(); - }; - let task = tokio::spawn(async move { server.start_with_shutdown(shutdown).await }); + let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await }); tokio::time::sleep(Duration::from_millis(50)).await; diff --git a/src/reduce.rs b/src/reduce.rs index 1eb21e4..808b212 100644 --- a/src/reduce.rs +++ b/src/reduce.rs @@ -1,22 +1,23 @@ use std::collections::HashMap; -use std::future::Future; +use std::fs; use std::path::PathBuf; use std::sync::Arc; -use chrono::{DateTime, TimeZone, Utc}; -use tokio::sync::mpsc; -use tokio::sync::mpsc::Sender; -use tokio::task::JoinSet; +use chrono::{DateTime, Utc}; +use tokio::sync::mpsc::{channel, Sender}; +use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::ReceiverStream; -use tonic::metadata::MetadataMap; +use tokio_stream::StreamExt; +use tokio_util::sync::CancellationToken; use tonic::{async_trait, Request, Response, Status}; +use crate::error::Error; +use crate::error::Error::ReduceError; +use crate::error::ErrorKind::{InternalError, UserDefinedError}; use crate::shared; +use crate::shared::prost_timestamp_from_utc; const KEY_JOIN_DELIMITER: &str = ":"; -// grpc window metadata -const WIN_START_TIME: &str = "x-numaflow-win-start-time"; -const WIN_END_TIME: &str = "x-numaflow-win-end-time"; const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/reduce.sock"; const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/reducer-server-info"; @@ -28,7 +29,9 @@ pub mod proto { } struct ReduceService { - creator: C, + creator: Arc, + shutdown_tx: Sender<()>, + cancellation_token: CancellationToken, // used to cancel all the tasks } /// `ReducerCreator` is a trait for creating a new instance of a `Reducer`. @@ -129,7 +132,7 @@ pub trait Reducer { /// ) -> Vec { /// let mut counter = 0; /// // the loop exits when input is closed which will happen only on close of book. - /// while (input.recv().await).is_some() { + /// while input.recv().await.is_some() { /// counter += 1; /// } /// let message=Message::new(counter.to_string().into_bytes()).tags(vec![]).keys(keys.clone()); @@ -148,6 +151,7 @@ pub trait Reducer { } /// IntervalWindow is the start and end boundary of the window. +#[derive(Default, Clone, Debug)] pub struct IntervalWindow { // start time of the window pub start_time: DateTime, @@ -170,6 +174,7 @@ impl Metadata { } } +#[derive(Debug)] /// Metadata are additional information passed into the [`Reducer::reduce`]. 
pub struct Metadata { pub interval_window: IntervalWindow, @@ -293,45 +298,6 @@ pub struct ReduceRequest { pub headers: HashMap, } -impl From for ReduceRequest { - fn from(mr: proto::ReduceRequest) -> Self { - Self { - keys: mr.keys, - value: mr.value, - watermark: shared::utc_from_timestamp(mr.watermark), - eventtime: shared::utc_from_timestamp(mr.event_time), - headers: mr.headers, - } - } -} - -// extract start and end time from the gRPC MetadataMap -// https://youtu.be/s5S2Ed5T-dc?t=662 -fn get_window_details(request: &MetadataMap) -> (DateTime, DateTime) { - let (st, et) = ( - request - .get(WIN_START_TIME) - .unwrap_or_else(|| panic!("expected key {}", WIN_START_TIME)) - .to_str() - .unwrap() - .to_string() - .parse::() - .unwrap(), - request - .get(WIN_END_TIME) - .unwrap_or_else(|| panic!("expected key {}", WIN_END_TIME)) - .to_str() - .unwrap() - .parse::() - .unwrap(), - ); - - ( - Utc.timestamp_millis_opt(st).unwrap(), - Utc.timestamp_millis_opt(et).unwrap(), - ) -} - #[async_trait] impl proto::reduce_server::Reduce for ReduceService where @@ -342,77 +308,452 @@ where &self, request: Request>, ) -> Result, Status> { - // get gRPC window from metadata - let (start_win, end_win) = get_window_details(request.metadata()); - let md = Arc::new(Metadata::new(IntervalWindow::new(start_win, end_win))); + // Clone the creator and shutdown_tx to be used in the spawned tasks. + let creator = Arc::clone(&self.creator); + let shutdown_tx = self.shutdown_tx.clone(); - let mut key_to_tx: HashMap> = HashMap::new(); + // Create a channel to send the response back to the grpc client. + let (grpc_response_tx, grpc_response_rx) = + channel::>(1); - // we will be creating a set of tasks for this stream - let mut set = JoinSet::new(); + // Internal response channel which will be used by the task set and tasks to send the response after + // executing the user defined function. It's a result type so in case of error, we can send the error + // back to the client. + // + // NOTE: we are using a separate channel instead of the grpc_response_tx because in case of errors, + // we have to do graceful shutdown. + let (response_tx, mut response_rx) = channel::>(1); - let mut stream = request.into_inner(); + // Start a task executor to handle the incoming ReduceRequests from the client, returns a tx to send + // commands to the task executor and an oneshot tx to abort all the tasks. + let (task_tx, abort_tx) = TaskSet::start_task_executor(creator, response_tx.clone()); - while let Some(rr) = stream.message().await.unwrap() { - let task_name = rr.keys.join(KEY_JOIN_DELIMITER); + // Spawn a new task to listen to the response channel and send the response back to the grpc client. + // In case of error, it propagates the error back to the client in grpc status format and sends a shutdown + // signal to the grpc server. It also listens to the cancellation signal and aborts all the tasks. + let response_task_token = self.cancellation_token.clone(); + tokio::spawn(async move { + loop { + tokio::select! { + result = response_rx.recv() => { + match result { + Some(Ok(response)) => { + let eof = response.eof; + grpc_response_tx + .send(Ok(response)) + .await + .expect("send to grpc response channel failed"); + // all the tasks are done (COB has happened and we have closed the tx for the tasks) + if eof { + break; + } + } + Some(Err(error)) => { + grpc_response_tx + .send(Err(Status::internal(error.to_string()))) + .await + .expect("send to grpc response channel failed"); + // Send a shutdown signal to the grpc server. 
+ shutdown_tx.send(()).await.expect("shutdown_tx send failed"); + } + None => { + // we break at eof, None should not happen + unreachable!() + } + } + } + _ = response_task_token.cancelled() => { + // Send an abort signal to the task executor to abort all the tasks. + abort_tx.send(()).expect("task_tx send failed"); + break; + } + } + } + }); - if let Some(tx) = key_to_tx.get(&task_name) { - tx.send(rr.into()).await.unwrap(); - } else { - // channel to send data to the user's reduce handle - let (tx, rx) = mpsc::channel::(1); + let request_cancel_token = self.cancellation_token.clone(); + + // Spawn a new task to handle the incoming ReduceRequests from the client + tokio::spawn(async move { + let mut stream = request.into_inner(); + loop { + tokio::select! { + reduce_request = stream.next() => { + match reduce_request { + Some(Ok(rr)) => { + task_tx + .send(TaskCommand::HandleReduceRequest(rr)) + .await + .expect("task_tx send failed"); + } + Some(Err(e)) => { + response_tx + .send(Err(ReduceError(InternalError(format!( + "Failed to receive request: {}", + e + ))))) + .await + .expect("error_tx send failed"); + } + // COB + None => break, + } + } + _ = request_cancel_token.cancelled() => { + // stop reading because server is shutting down + break; + } + } + } + task_tx + .send(TaskCommand::Close) + .await + .expect("task_tx send failed"); + }); + + // return the rx as the streaming endpoint + Ok(Response::new(ReceiverStream::new(grpc_response_rx))) + } + + async fn is_ready(&self, _: Request<()>) -> Result, Status> { + Ok(Response::new(proto::ReadyResponse { ready: true })) + } +} + +// The `Task` struct represents a task in the reduce service. It is responsible for executing the +// user defined function. We will a separate task for each keyed window. The task will be created +// when the first message for a given key arrives and will be closed when the window is closed. +struct Task { + udf_tx: Sender, + response_tx: Sender>, + done_rx: oneshot::Receiver<()>, + handle: tokio::task::JoinHandle<()>, +} + +// we only have one slot +const SLOT_0: &str = "slot-0"; + +impl Task { + // Creates a new task with the given reducer, keys, metadata, and response channel. + async fn new( + reducer: R, + keys: Vec, + md: Metadata, + response_tx: Sender>, + ) -> Self { + let (udf_tx, udf_rx) = channel::(1); + let (done_tx, done_rx) = oneshot::channel(); - // since we are calling this in a loop, we need make sure that there is reference counting - // and the lifetime of self is more than the async function. - // try Arc https://doc.rust-lang.org/reference/items/associated-items.html#methods ? 
- let handler = self.creator.create(); - let m = Arc::clone(&md); + let udf_response_tx = response_tx.clone(); + let task_join_handler = tokio::spawn(async move { + // execute user code + let messages = reducer.reduce(keys, udf_rx, &md).await; - // spawn task for each unique key - let keys = rr.keys.clone(); - set.spawn(async move { handler.reduce(keys, rx, m.as_ref()).await }); + // forward the responses + for message in messages { + let send_result = udf_response_tx + .send(Ok(proto::ReduceResponse { + result: Some(proto::reduce_response::Result { + keys: message.keys.unwrap_or_default(), + value: message.value, + tags: message.tags.unwrap_or_default(), + }), + window: Some(proto::Window { + start: prost_timestamp_from_utc(md.interval_window.start_time), + end: prost_timestamp_from_utc(md.interval_window.end_time), + slot: SLOT_0.to_string(), + }), + eof: false, + })) + .await; - // write data into the channel - tx.send(rr.into()).await.unwrap(); + if let Err(e) = send_result { + let _ = udf_response_tx + .send(Err(ReduceError(InternalError(format!( + "Failed to send response back: {}", + e + ))))) + .await; + return; + } + } + }); - // save the key and for future look up as long as the stream is active - key_to_tx.insert(task_name, tx); + // We spawn a separate task to await the join handler so that in case of any unhandled errors in the user-defined + // code will immediately be propagated to the client. + let handler_tx = response_tx.clone(); + let handle = tokio::spawn(async move { + if let Err(e) = task_join_handler.await { + let _ = handler_tx + .send(Err(ReduceError(UserDefinedError(format!(" {}", e))))) + .await; } + + // Send a message indicating that the task has finished + let _ = done_tx.send(()); + }); + + Self { + udf_tx, + response_tx, + done_rx, + // We store the task join handle so that we can abort the task if needed, we only need the second task handle because + // if the second task is aborted, the first task's handle will be dropped and the task will be aborted. + handle, + } + } + + // Sends the request to the user defined function's input channel. + async fn send(&self, rr: ReduceRequest) { + if let Err(e) = self.udf_tx.send(rr).await { + self.response_tx + .send(Err(ReduceError(InternalError(format!( + "Failed to send message to task: {}", + e + ))))) + .await + .expect("failed to send message to error channel"); } + } + + // Closes the task and waits for it to finish. + async fn close(self) { + // drop the sender to close the task + drop(self.udf_tx); + + // Wait for the task to finish + let _ = self.done_rx.await; + } + + // Aborts the task by calling abort on join handler. + async fn abort(self) { + self.handle.abort(); + } +} + +// The `TaskSet` struct represents a set of tasks that are executing the user defined function. It is responsible +// for creating new tasks, writing messages to the tasks, closing the tasks, and aborting the tasks. +struct TaskSet { + tasks: HashMap, + response_tx: Sender>, + creator: Arc, + window: IntervalWindow, +} - // close all the tx channels to tasks to close their corresponding rx - key_to_tx.clear(); +enum TaskCommand { + HandleReduceRequest(proto::ReduceRequest), + Close, +} - // channel to respond to numaflow main car as it expects streaming results. - let (tx, rx) = mpsc::channel::>(1); +impl TaskSet +where + C: ReducerCreator + Send + Sync + 'static, +{ + // Starts a new task executor which listens to incoming commands and executes them. 
+ // returns a tx to send commands to the task executor and oneshot tx to abort all + // the tasks to gracefully shut down the task executor. + fn start_task_executor( + creator: Arc, + response_tx: Sender>, + ) -> (Sender, oneshot::Sender<()>) { + let (task_tx, mut task_rx) = channel::(1); + let (abort_tx, mut abort_rx) = oneshot::channel(); - // start the result streamer + let mut task_set = TaskSet { + tasks: HashMap::new(), + response_tx, + creator, + window: IntervalWindow::default(), + }; + + // Start a new task to listen to incoming commands and execute them, it will also listen to the abort signal. tokio::spawn(async move { - while let Some(res) = set.join_next().await { - let messages = res.unwrap(); - let mut datum_responses = vec![]; - for message in messages { - datum_responses.push(proto::reduce_response::Result { - keys: message.keys.unwrap_or_default(), - value: message.value, - tags: message.tags.unwrap_or_default(), - }); + loop { + tokio::select! { + cmd = task_rx.recv() => { + match cmd { + Some(TaskCommand::HandleReduceRequest(rr)) => { + // Extract the keys from the ReduceRequest. + let keys = match rr.payload.as_ref() { + Some(payload) => payload.keys.clone(), + None => { + task_set + .handle_error(ReduceError(InternalError( + "Invalid ReduceRequest".to_string(), + ))) + .await; + continue; + } + }; + + // Check if the task already exists, if it does, write the ReduceRequest to the task, + // otherwise create a new task and write the ReduceRequest to the task. + if task_set.tasks.contains_key(&keys.join(KEY_JOIN_DELIMITER)) { + task_set.write_to_task(keys, rr).await; + } else { + task_set.create_and_write(keys, rr).await; + } + } + Some(TaskCommand::Close) => task_set.close().await, + // COB + None => break, + } + } + _ = &mut abort_rx => { + task_set.abort().await; + break; + } } - // stream it out to the client - tx.send(Ok(proto::ReduceResponse { - results: datum_responses, - })) - .await - .unwrap(); } }); - // return the rx as the streaming endpoint - Ok(Response::new(ReceiverStream::new(rx))) + (task_tx, abort_tx) } - async fn is_ready(&self, _: Request<()>) -> Result, Status> { - Ok(Response::new(proto::ReadyResponse { ready: true })) + // Creates a new task with the given keys and `ReduceRequest`. + // It creates a new reducer and assigns it to the task to execute the user defined function. + async fn create_and_write(&mut self, keys: Vec, rr: proto::ReduceRequest) { + // validate + let (reduce_request, interval_window) = match self.validate_and_extract(rr).await { + Some(value) => value, + None => return, + }; + + self.window = interval_window.clone(); + + // Create a new reducer + let reducer = self.creator.create(); + + // Create Metadata with the extracted start and end time + let md = Metadata::new(interval_window); + + // Create a new Task with the reducer, keys, and metadata + let task = Task::new(reducer, keys.clone(), md, self.response_tx.clone()).await; + + // track the task in the task set + self.tasks.insert(keys.join(KEY_JOIN_DELIMITER), task); + + // send the request inside the proto payload to the task + // if the task does not exist, send an error to the stream + if let Some(task) = self.tasks.get(&keys.join(KEY_JOIN_DELIMITER)) { + task.send(reduce_request).await; + } else { + self.handle_error(ReduceError(InternalError("Task not found".to_string()))) + .await; + } + } + + // Writes the ReduceRequest to the task with the given keys. 
+ async fn write_to_task(&mut self, keys: Vec, rr: proto::ReduceRequest) { + // validate the request + let (reduce_request, _) = match self.validate_and_extract(rr).await { + Some(value) => value, + None => return, + }; + + // Get the task name from the keys + let task_name = keys.join(KEY_JOIN_DELIMITER); + + // If the task exists, send the ReduceRequest to the task + if let Some(task) = self.tasks.get(&task_name) { + task.send(reduce_request).await; + } else { + self.handle_error(ReduceError(InternalError("Task not found".to_string()))) + .await; + } + } + + // Validates the ReduceRequest and extracts the payload and window information. + // If the ReduceRequest is invalid, it sends an error to the response stream and returns None. + async fn validate_and_extract( + &self, + rr: proto::ReduceRequest, + ) -> Option<(ReduceRequest, IntervalWindow)> { + // Extract the payload and window information from the ReduceRequest + let (payload, windows) = match (rr.payload, rr.operation) { + (Some(payload), Some(operation)) => (payload, operation.windows), + _ => { + self.handle_error(ReduceError(InternalError( + "Invalid ReduceRequest".to_string(), + ))) + .await; + return None; + } + }; + + // Check if there is exactly one window in the ReduceRequest + if windows.len() != 1 { + self.handle_error(ReduceError(InternalError( + "Exactly one window is required".to_string(), + ))) + .await; + return None; + } + + // Extract the start and end time from the window + let window = &windows[0]; + let (start_time, end_time) = ( + shared::utc_from_timestamp(window.start.clone()), + shared::utc_from_timestamp(window.end.clone()), + ); + + // Create the IntervalWindow + let interval_window = IntervalWindow::new(start_time, end_time); + + // Create the ReduceRequest + let reduce_request = ReduceRequest { + keys: payload.keys, + value: payload.value, + watermark: shared::utc_from_timestamp(payload.watermark), + eventtime: shared::utc_from_timestamp(payload.event_time), + headers: payload.headers, + }; + + Some((reduce_request, interval_window)) + } + + // Closes all tasks in the task set and sends an EOF message to the response stream. + async fn close(&mut self) { + for (_, task) in self.tasks.drain() { + task.close().await; + } + + // after all the tasks have been closed, send an EOF message to the response stream + let send_eof = self + .response_tx + .send(Ok(proto::ReduceResponse { + result: None, + window: Some(proto::Window { + start: prost_timestamp_from_utc(self.window.start_time), + end: prost_timestamp_from_utc(self.window.end_time), + slot: "slot-0".to_string(), + }), + eof: true, + })) + .await; + + if let Err(e) = send_eof { + self.handle_error(ReduceError(InternalError(format!( + "Failed to send EOF message: {}", + e + )))) + .await; + } + } + + // Aborts all tasks in the task set. + async fn abort(&mut self) { + for (_, task) in self.tasks.drain() { + task.abort().await; + } + } + + // Sends an error to the response stream. + async fn handle_error(&self, error: Error) { + self.response_tx + .send(Err(error)) + .await + .expect("error_tx send failed"); } } @@ -471,33 +812,456 @@ impl Server { } /// Starts the gRPC server. When message is received on the `shutdown` channel, graceful shutdown of the gRPC server will be initiated. 
- pub async fn start_with_shutdown( + pub async fn start_with_shutdown( &mut self, - shutdown: F, + user_shutdown_rx: oneshot::Receiver<()>, ) -> Result<(), Box> where - F: Future, C: ReducerCreator + Send + Sync + 'static, { let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?; let creator = self.creator.take().unwrap(); - let reduce_svc = ReduceService { creator }; + let (internal_shutdown_tx, internal_shutdown_rx) = channel(1); + let cln_token = CancellationToken::new(); + let reduce_svc = ReduceService { + creator: Arc::new(creator), + shutdown_tx: internal_shutdown_tx, + cancellation_token: cln_token.clone(), + }; let reduce_svc = proto::reduce_server::ReduceServer::new(reduce_svc) .max_encoding_message_size(self.max_message_size) .max_decoding_message_size(self.max_message_size); + let shutdown = + shared::shutdown_signal(internal_shutdown_rx, Some(user_shutdown_rx), cln_token); + tonic::transport::Server::builder() .add_service(reduce_svc) .serve_with_incoming_shutdown(listener, shutdown) - .await - .map_err(Into::into) + .await?; + + Ok(()) } - /// Starts the gRPC server. Automatically registers signal handlers for SIGINT and SIGTERM and initiates graceful shutdown of gRPC server when either one of the signal arrives. + /// Starts the gRPC server. Automatically registers signal handlers for SIGINT and SIGTERM and initiates + /// graceful shutdown of gRPC server when either one of the signal arrives. pub async fn start(&mut self) -> Result<(), Box> where C: ReducerCreator + Send + Sync + 'static, { - self.start_with_shutdown(shared::shutdown_signal()).await + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + self.start_with_shutdown(shutdown_rx).await + } +} + +impl Drop for Server { + // Cleanup the socket file when the server is dropped so that when the server is restarted, it can bind to the + // same address. UnixListener doesn't implement Drop trait, so we have to manually remove the socket file. 
+ fn drop(&mut self) { + let _ = fs::remove_file(&self.sock_addr); + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + use std::{error::Error, time::Duration}; + + use prost_types::Timestamp; + use tempfile::TempDir; + use tokio::sync::{mpsc, oneshot}; + use tokio::time::sleep; + use tokio_stream::wrappers::ReceiverStream; + use tonic::transport::Uri; + use tonic::Request; + use tower::service_fn; + + use crate::reduce; + use crate::reduce::proto::reduce_client::ReduceClient; + + struct Sum; + + #[tonic::async_trait] + impl reduce::Reducer for Sum { + async fn reduce( + &self, + _keys: Vec, + mut input: mpsc::Receiver, + _md: &reduce::Metadata, + ) -> Vec { + let mut sum = 0; + while let Some(rr) = input.recv().await { + sum += std::str::from_utf8(&rr.value) + .unwrap() + .parse::() + .unwrap(); + } + vec![reduce::Message::new(sum.to_string().into_bytes())] + } + } + + struct SumCreator; + + impl reduce::ReducerCreator for SumCreator { + type R = Sum; + fn create(&self) -> Sum { + Sum {} + } + } + + async fn setup_server( + creator: C, + ) -> Result<(reduce::Server, PathBuf, PathBuf), Box> { + let tmp_dir = TempDir::new()?; + let sock_file = tmp_dir.path().join("reduce.sock"); + let server_info_file = tmp_dir.path().join("reducer-server-info"); + + let server = reduce::Server::new(creator) + .with_server_info_file(&server_info_file) + .with_socket_file(&sock_file) + .with_max_message_size(10240); + + Ok((server, sock_file, server_info_file)) + } + + async fn setup_client( + sock_file: PathBuf, + ) -> Result, Box> { + // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs + let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")? + .connect_with_connector(service_fn(move |_: Uri| { + // Connect to an Uds socket + let sock_file = sock_file.clone(); + tokio::net::UnixStream::connect(sock_file) + })) + .await?; + + let client = ReduceClient::new(channel); + + Ok(client) + } + + #[tokio::test] + async fn test_server_start() -> Result<(), Box> { + let (mut server, sock_file, server_info_file) = setup_server(SumCreator).await?; + + assert_eq!(server.max_message_size(), 10240); + assert_eq!(server.server_info_file(), server_info_file); + assert_eq!(server.socket_file(), sock_file); + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + // Check if the server has started + assert!(!task.is_finished(), "gRPC server should be running"); + + // Send shutdown signal + shutdown_tx + .send(()) + .expect("Sending shutdown signal to gRPC server"); + + // Check if the server has stopped within 100 ms + for _ in 0..10 { + tokio::time::sleep(Duration::from_millis(10)).await; + if task.is_finished() { + break; + } + } + assert!(task.is_finished(), "gRPC server is still running"); + + Ok(()) + } + + #[tokio::test] + async fn valid_input() -> Result<(), Box> { + let (mut server, sock_file, _) = setup_server(SumCreator).await?; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + let mut client = setup_client(sock_file).await?; + + let (tx, rx) = mpsc::channel(1); + + // Spawn a task to send ReduceRequests to the channel + tokio::spawn(async move { + let data = vec![("key1".to_string(), 1..=10), ("key2".to_string(), 1..=9)]; + + for (key, range) in data { + for 
i in range { + let rr = reduce::proto::ReduceRequest { + payload: Some(reduce::proto::reduce_request::Payload { + keys: vec![key.clone()], + value: i.to_string().as_bytes().to_vec(), + watermark: None, + event_time: None, + headers: Default::default(), + }), + operation: Some(reduce::proto::reduce_request::WindowOperation { + event: 0, + windows: vec![reduce::proto::Window { + start: Some(Timestamp { + seconds: 60000, + nanos: 0, + }), + end: Some(Timestamp { + seconds: 120000, + nanos: 0, + }), + slot: "slot-0".to_string(), + }], + }), + }; + + tx.send(rr).await.unwrap(); + } + } + }); + + // Convert the receiver end of the channel into a stream + let stream = ReceiverStream::new(rx); + + // Create a tonic::Request from the stream + let request = Request::new(stream); + + // Send the request to the server + let resp = client.reduce_fn(request).await?; + + let mut response_stream = resp.into_inner(); + let mut responses = Vec::new(); + + while let Some(response) = response_stream.message().await? { + responses.push(response); + } + + // since we are sending two different keys, we should get two responses + 1 EOF + assert_eq!(responses.len(), 3); + + for (i, response) in responses.iter().enumerate() { + if let Some(window) = response.window.as_ref() { + if let Some(start) = window.start.as_ref() { + assert_eq!(start.seconds, 60000); + } + if let Some(end) = window.end.as_ref() { + assert_eq!(end.seconds, 120000); + } + } + + if let Some(result) = response.result.as_ref() { + if result.keys == vec!["key1".to_string()] { + assert_eq!(result.value, 55.to_string().into_bytes()); + } else if result.keys == vec!["key2".to_string()] { + assert_eq!(result.value, 45.to_string().into_bytes()); + } + } + + // Check if this is the last message in the stream + // The last message should have eof set to true + if i == responses.len() - 1 { + assert!(response.eof); + } else { + assert!(!response.eof); + } + } + + shutdown_tx + .send(()) + .expect("Sending shutdown signal to gRPC server"); + + for _ in 0..10 { + tokio::time::sleep(Duration::from_millis(10)).await; + if task.is_finished() { + break; + } + } + assert!(task.is_finished(), "gRPC server is still running"); + + Ok(()) + } + + #[tokio::test] + async fn invalid_input() -> Result<(), Box> { + let (mut server, sock_file, _) = setup_server(SumCreator).await?; + + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + + let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + let mut client = setup_client(sock_file).await?; + + let (tx, rx) = mpsc::unbounded_channel(); + + // Spawn a task to send ReduceRequests to the channel + let _sender_task = tokio::spawn(async move { + for _ in 0..10 { + let rr = reduce::proto::ReduceRequest { + payload: Some(reduce::proto::reduce_request::Payload { + keys: vec!["key1".to_string()], + value: vec![], + watermark: None, + event_time: None, + headers: Default::default(), + }), + operation: Some(reduce::proto::reduce_request::WindowOperation { + event: 0, + windows: vec![ + reduce::proto::Window { + start: Some(Timestamp { + seconds: 60000, + nanos: 0, + }), + end: Some(Timestamp { + seconds: 120000, + nanos: 0, + }), + slot: "slot-0".to_string(), + }, + reduce::proto::Window { + start: Some(Timestamp { + seconds: 60000, + nanos: 0, + }), + end: Some(Timestamp { + seconds: 120000, + nanos: 0, + }), + slot: "slot-0".to_string(), + }, + ], + }), + }; + + tx.send(rr).unwrap(); + sleep(Duration::from_millis(10)).await; + } + }); + 
+ // Send the request to the server + let resp = client + .reduce_fn(Request::new( + tokio_stream::wrappers::UnboundedReceiverStream::new(rx), + )) + .await; + + let mut response_stream = resp.unwrap().into_inner(); + + if let Err(e) = response_stream.message().await { + assert_eq!(e.code(), tonic::Code::Internal); + assert!(e.message().contains("Exactly one window is required")); + } + + for _ in 0..10 { + tokio::time::sleep(Duration::from_millis(10)).await; + if task.is_finished() { + break; + } + } + + assert!(task.is_finished(), "gRPC server is still running"); + Ok(()) + } + + struct PanicReducer; + + #[tonic::async_trait] + impl reduce::Reducer for PanicReducer { + async fn reduce( + &self, + _keys: Vec, + _input: mpsc::Receiver, + _md: &reduce::Metadata, + ) -> Vec { + panic!("Panic in reduce method"); + } + } + + struct PanicReducerCreator; + + impl reduce::ReducerCreator for PanicReducerCreator { + type R = PanicReducer; + fn create(&self) -> PanicReducer { + PanicReducer {} + } + } + + #[tokio::test] + async fn panic_in_reduce() -> Result<(), Box> { + let (mut server, sock_file, _) = setup_server(PanicReducerCreator).await?; + + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + + let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + let mut client = setup_client(sock_file.clone()).await?; + + let (tx, rx) = mpsc::channel(1); + + // Spawn a task to send ReduceRequests to the channel + tokio::spawn(async move { + let rr = reduce::proto::ReduceRequest { + payload: Some(reduce::proto::reduce_request::Payload { + keys: vec!["key1".to_string()], + value: vec![], + watermark: None, + event_time: None, + headers: Default::default(), + }), + operation: Some(reduce::proto::reduce_request::WindowOperation { + event: 0, + windows: vec![reduce::proto::Window { + start: Some(Timestamp { + seconds: 60000, + nanos: 0, + }), + end: Some(Timestamp { + seconds: 120000, + nanos: 0, + }), + slot: "slot-0".to_string(), + }], + }), + }; + + for _ in 0..10 { + tx.send(rr.clone()).await.unwrap(); + sleep(Duration::from_millis(10)).await; + } + }); + + // Convert the receiver end of the channel into a stream + let stream = ReceiverStream::new(rx); + + // Create a tonic::Request from the stream + let request = Request::new(stream); + + // Send the request to the server + let resp = client.reduce_fn(request).await?; + + let mut response_stream = resp.into_inner(); + + if let Err(e) = response_stream.message().await { + assert_eq!(e.code(), tonic::Code::Internal); + assert!(e.message().contains("User Defined Error")) + } + + for _ in 0..10 { + tokio::time::sleep(Duration::from_millis(10)).await; + if task.is_finished() { + break; + } + } + assert!(task.is_finished(), "gRPC server is still running"); + + Ok(()) } } diff --git a/src/shared.rs b/src/shared.rs index 8d978f7..9d4a88a 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -4,12 +4,15 @@ use std::{collections::HashMap, io}; use chrono::{DateTime, TimeZone, Timelike, Utc}; use prost_types::Timestamp; +use tokio::net::UnixListener; use tokio::signal; +use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::UnixListenerStream; +use tokio_util::sync::CancellationToken; use tracing::info; // #[tracing::instrument(skip(path), fields(path = ?path.as_ref()))] -#[tracing::instrument(fields(path = ?path.as_ref()))] +#[tracing::instrument(fields(path = ? 
path.as_ref()))] fn write_info_file(path: impl AsRef) -> io::Result<()> { let parent = path.as_ref().parent().unwrap(); std::fs::create_dir_all(parent)?; @@ -35,18 +38,13 @@ pub(crate) fn create_listener_stream( ) -> Result> { write_info_file(server_info_file).map_err(|e| format!("writing info file: {e:?}"))?; - let parent = socket_file.as_ref().parent().unwrap(); - std::fs::create_dir_all(parent).map_err(|e| format!("creating directory {parent:?}: {e:?}"))?; - - let uds = tokio::net::UnixListener::bind(socket_file)?; - Ok(tokio_stream::wrappers::UnixListenerStream::new(uds)) + let uds_stream = UnixListener::bind(socket_file)?; + Ok(UnixListenerStream::new(uds_stream)) } pub(crate) fn utc_from_timestamp(t: Option) -> DateTime { t.map_or(Utc.timestamp_nanos(-1), |t| { - DateTime::from_timestamp(t.seconds, t.nanos as u32).unwrap_or(Utc.timestamp_nanos(-1)) - }) } @@ -57,7 +55,20 @@ pub(crate) fn prost_timestamp_from_utc(t: DateTime) -> Option { }) } -pub(crate) async fn shutdown_signal() { +/// Shuts down the gRPC server. This happens in 2 cases: +/// 1. there has been an internal error (one of the tasks failed) and we need to shut down +/// 2. the user is explicitly asking us to shut down +/// Once the request for shutdown has been invoked, the server will broadcast shutdown to all tasks +/// through the cancellation token. +pub(crate) async fn shutdown_signal( + mut shutdown_on_err: mpsc::Receiver<()>, + shutdown_from_user: Option>, + cancel_token: CancellationToken, +) { + // will call cancel_token.cancel() when the function exits + // because of abort request, ctrl-c, or SIGTERM signal + let _drop_guard = cancel_token.drop_guard(); + let ctrl_c = async { signal::ctrl_c() .await @@ -66,20 +77,38 @@ pub(crate) async fn shutdown_signal() { let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install SITERM handler") + .expect("failed to install SIGTERM handler") .recv() .await; }; + + let shutdown_on_err_future = async { + shutdown_on_err.recv().await; + }; + + let shutdown_from_user_future = async { + if let Some(rx) = shutdown_from_user { + rx.await.ok(); + } + }; + tokio::select!
{ _ = ctrl_c => {}, _ = terminate => {}, + _ = shutdown_on_err_future => {}, + _ = shutdown_from_user_future => {}, } } + #[cfg(test)] -mod tests{ +mod tests { use super::*; + use std::fs::File; + use std::io::Read; + use tempfile::NamedTempFile; + #[test] - fn test_utc_from_timestamp(){ + fn test_utc_from_timestamp() { let specific_date = Utc.with_ymd_and_hms(2022, 7, 2, 2, 0, 0).unwrap(); let timestamp = Timestamp { @@ -87,30 +116,32 @@ mod tests{ nanos: specific_date.timestamp_subsec_nanos() as i32, }; - let utc_ts=utc_from_timestamp(Some(timestamp)); - assert_eq!(utc_ts,specific_date) + let utc_ts = utc_from_timestamp(Some(timestamp)); + assert_eq!(utc_ts, specific_date) } + #[test] - fn test_utc_from_timestamp_epoch_0(){ + fn test_utc_from_timestamp_epoch_0() { let specific_date = Utc.timestamp_nanos(-1); - let utc_ts=utc_from_timestamp(None); - assert_eq!(utc_ts,specific_date) + let utc_ts = utc_from_timestamp(None); + assert_eq!(utc_ts, specific_date) } + #[test] - fn test_prost_timestamp_from_utc(){ + fn test_prost_timestamp_from_utc() { let specific_date = Utc.with_ymd_and_hms(2022, 7, 2, 2, 0, 0).unwrap(); let timestamp = Timestamp { seconds: specific_date.timestamp(), nanos: specific_date.timestamp_subsec_nanos() as i32, }; - let prost_ts=prost_timestamp_from_utc(specific_date); - assert_eq!(prost_ts,Some(timestamp)) + let prost_ts = prost_timestamp_from_utc(specific_date); + assert_eq!(prost_ts, Some(timestamp)) } #[test] fn test_prost_timestamp_from_utc_epoch_0() { - let specific_date = Utc.timestamp(0, 0); + let specific_date = Utc.timestamp_nanos(0); let timestamp = Timestamp { seconds: 0, nanos: 0, @@ -119,5 +150,51 @@ mod tests{ assert_eq!(prost_ts, Some(timestamp)); } + #[tokio::test] + async fn test_write_info_file() -> io::Result<()> { + // Create a temporary file + let temp_file = NamedTempFile::new()?; + + // Call write_info_file with the path of the temporary file + write_info_file(temp_file.path())?; + + // Open the file and read its contents + let mut file = File::open(temp_file.path())?; + let mut contents = String::new(); + file.read_to_string(&mut contents)?; + + // Check if the contents of the file are as expected + assert!(contents.contains(r#""protocol":"uds""#)); + assert!(contents.contains(r#""language":"rust""#)); + assert!(contents.contains(r#""version":"0.0.1""#)); + assert!(contents.contains(r#""metadata":{}"#)); + + Ok(()) + } + + #[tokio::test] + async fn test_shutdown_signal() { + // Create a channel to send shutdown signal + let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); + let (_user_shutdown_tx, user_shutdown_rx) = oneshot::channel(); + + // Spawn a new task to call shutdown_signal + let shutdown_signal_task = tokio::spawn(async move { + shutdown_signal( + internal_shutdown_rx, + Some(user_shutdown_rx), + CancellationToken::new(), + ) + .await; + }); + + // Send a shutdown signal + internal_shutdown_tx.send(()).await.unwrap(); -} \ No newline at end of file + // Wait for the shutdown_signal function to finish + let result = shutdown_signal_task.await; + + // If we reach this point, it means that the shutdown_signal function has correctly handled the shutdown signal + assert!(result.is_ok()); + } +} diff --git a/src/sideinput.rs b/src/sideinput.rs index ef46886..b5d3bb6 100644 --- a/src/sideinput.rs +++ b/src/sideinput.rs @@ -1,8 +1,13 @@ -use crate::shared; -use std::future::Future; +use std::fs; use std::path::PathBuf; + +use tokio::sync::{mpsc, oneshot}; +use tokio_util::sync::CancellationToken; use tonic::{async_trait, 
Request, Response, Status}; +use crate::shared; +use crate::shared::shutdown_signal; + const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sideinput.sock"; const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sideinput-server-info"; @@ -13,6 +18,7 @@ mod proto { struct SideInputService { handler: T, + _shutdown_tx: mpsc::Sender<()>, } /// The `SideInputer` trait defines a method for retrieving side input data. @@ -163,26 +169,37 @@ impl Server { } /// Starts the gRPC server. When message is received on the `shutdown` channel, graceful shutdown of the gRPC server will be initiated. - pub async fn start_with_shutdown( + pub async fn start_with_shutdown( &mut self, - shutdown: F, + shutdown_rx: oneshot::Receiver<()>, ) -> Result<(), Box> where T: SideInputer + Send + Sync + 'static, - F: Future, { let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?; let handler = self.svc.take().unwrap(); - let sideinput_svc = SideInputService { handler }; + let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); + + let sideinput_svc = SideInputService { + handler, + _shutdown_tx: internal_shutdown_tx, + }; let sideinput_svc = proto::side_input_server::SideInputServer::new(sideinput_svc) .max_encoding_message_size(self.max_message_size) .max_decoding_message_size(self.max_message_size); + let shutdown = shutdown_signal( + internal_shutdown_rx, + Some(shutdown_rx), + CancellationToken::new(), + ); + tonic::transport::Server::builder() .add_service(sideinput_svc) .serve_with_incoming_shutdown(listener, shutdown) - .await - .map_err(Into::into) + .await?; + + Ok(()) } /// Starts the gRPC server. Automatically registers signal handlers for SIGINT and SIGTERM and initiates graceful shutdown of gRPC server when either one of the signal arrives. @@ -190,6 +207,15 @@ impl Server { where T: SideInputer + Send + Sync + 'static, { - self.start_with_shutdown(shared::shutdown_signal()).await + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + self.start_with_shutdown(shutdown_rx).await + } +} + +impl Drop for Server { + // Cleanup the socket file when the server is dropped so that when the server is restarted, it can bind to the + // same address. UnixListener doesn't implement Drop trait, so we have to manually remove the socket file. + fn drop(&mut self) { + let _ = fs::remove_file(&self.sock_addr); } } diff --git a/src/sink.rs b/src/sink.rs index 1109028..7870487 100644 --- a/src/sink.rs +++ b/src/sink.rs @@ -1,10 +1,10 @@ use std::collections::HashMap; -use std::env; -use std::future::Future; use std::path::PathBuf; +use std::{env, fs}; use chrono::{DateTime, Utc}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; +use tokio_util::sync::CancellationToken; use tonic::{Request, Status, Streaming}; use crate::shared; @@ -24,7 +24,8 @@ pub mod proto { } struct SinkService { - pub handler: T, + handler: T, + _shutdown_tx: mpsc::Sender<()>, } /// Sinker trait for implementing user defined sinks. @@ -280,26 +281,38 @@ impl Server { } /// Starts the gRPC server. When message is received on the `shutdown` channel, graceful shutdown of the gRPC server will be initiated. 
-    pub async fn start_with_shutdown(
+    pub async fn start_with_shutdown(
         &mut self,
-        shutdown: F,
+        shutdown_rx: oneshot::Receiver<()>,
     ) -> Result<(), Box>
     where
         T: Sinker + Send + Sync + 'static,
-        F: Future,
     {
         let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?;
         let handler = self.svc.take().unwrap();
-        let svc = SinkService { handler };
+        let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);
+
+        let svc = SinkService {
+            handler,
+            _shutdown_tx: internal_shutdown_tx,
+        };
+
         let svc = proto::sink_server::SinkServer::new(svc)
             .max_encoding_message_size(self.max_message_size)
             .max_decoding_message_size(self.max_message_size);
+        let shutdown = shared::shutdown_signal(
+            internal_shutdown_rx,
+            Some(shutdown_rx),
+            CancellationToken::new(),
+        );
+
         tonic::transport::Server::builder()
             .add_service(svc)
             .serve_with_incoming_shutdown(listener, shutdown)
-            .await
-            .map_err(Into::into)
+            .await?;
+
+        Ok(())
     }
 
     /// Starts the gRPC server. Automatically registers signal handlers for SIGINT and SIGTERM and initiates graceful shutdown of gRPC server when either one of the singal arrives.
@@ -307,7 +320,16 @@ impl Server {
     where
         T: Sinker + Send + Sync + 'static,
     {
-        self.start_with_shutdown(shared::shutdown_signal()).await
+        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
+        self.start_with_shutdown(shutdown_rx).await
+    }
+}
+
+impl Drop for Server {
+    // Cleanup the socket file when the server is dropped so that when the server is restarted, it can bind to the
+    // same address. UnixListener doesn't implement Drop trait, so we have to manually remove the socket file.
+    fn drop(&mut self) {
+        let _ = fs::remove_file(&self.sock_addr);
     }
 }
@@ -372,10 +394,7 @@ mod tests {
         assert_eq!(server.socket_file(), sock_file);
 
         let (shutdown_tx, shutdown_rx) = oneshot::channel();
-        let shutdown = async {
-            shutdown_rx.await.unwrap();
-        };
-        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown).await });
+        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
         tokio::time::sleep(Duration::from_millis(50)).await;
diff --git a/src/source.rs b/src/source.rs
index 848c43c..d474eaa 100644
--- a/src/source.rs
+++ b/src/source.rs
@@ -1,17 +1,18 @@
-#![warn(missing_docs)]
-
 use std::collections::HashMap;
-use std::future::Future;
+use std::fs;
 use std::path::PathBuf;
 use std::sync::Arc;
 use std::time::Duration;
-use crate::shared::{self, prost_timestamp_from_utc};
 use chrono::{DateTime, Utc};
 use tokio::sync::mpsc::{self, Sender};
+use tokio::sync::oneshot;
 use tokio_stream::wrappers::ReceiverStream;
+use tokio_util::sync::CancellationToken;
 use tonic::{async_trait, Request, Response, Status};
+
+use crate::shared::{self, prost_timestamp_from_utc};
+
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/source.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcer-server-info";
@@ -23,6 +24,7 @@ pub mod proto {
 
 struct SourceService {
     handler: Arc,
+    _shutdown_tx: Sender<()>,
 }
 
 #[async_trait]
@@ -196,7 +198,7 @@ pub struct Message {
     pub event_time: DateTime,
     /// Keys of the message.
     pub keys: Vec,
-
+    /// Headers of the message.
     pub headers: HashMap,
 }
@@ -255,29 +257,38 @@ impl Server {
     }
 
     /// Starts the gRPC server. When message is received on the `shutdown` channel, graceful shutdown of the gRPC server will be initiated.
-    pub async fn start_with_shutdown(
+    pub async fn start_with_shutdown(
         &mut self,
-        shutdown: F,
+        shutdown_rx: oneshot::Receiver<()>,
     ) -> Result<(), Box>
     where
         T: Sourcer + Send + Sync + 'static,
-        F: Future,
     {
         let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?;
         let handler = self.svc.take().unwrap();
+        let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);
+
         let source_service = SourceService {
             handler: Arc::new(handler),
+            _shutdown_tx: internal_shutdown_tx,
         };
 
         let source_svc = proto::source_server::SourceServer::new(source_service)
             .max_decoding_message_size(self.max_message_size)
             .max_decoding_message_size(self.max_message_size);
+        let shutdown = shared::shutdown_signal(
+            internal_shutdown_rx,
+            Some(shutdown_rx),
+            CancellationToken::new(),
+        );
+
         tonic::transport::Server::builder()
             .add_service(source_svc)
             .serve_with_incoming_shutdown(listener, shutdown)
-            .await
-            .map_err(Into::into)
+            .await?;
+
+        Ok(())
     }
 
     /// Starts the gRPC server. Automatically registers singal handlers for SIGINT and SIGTERM and initiates graceful shutdown of gRPC server when either one of the singal arrives.
@@ -285,26 +296,34 @@ impl Server {
     where
         T: Sourcer + Send + Sync + 'static,
     {
-        self.start_with_shutdown(shared::shutdown_signal()).await
+        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
+        self.start_with_shutdown(shutdown_rx).await
+    }
+}
+
+impl Drop for Server {
+    // Cleanup the socket file when the server is dropped so that when the server is restarted, it can bind to the
+    // same address. UnixListener doesn't implement Drop trait, so we have to manually remove the socket file.
+    fn drop(&mut self) {
+        let _ = fs::remove_file(&self.sock_addr);
     }
 }
 
 #[cfg(test)]
 mod tests {
-    use super::proto;
+    use super::{proto, Message, Offset, SourceReadRequest};
     use chrono::Utc;
     use std::collections::{HashMap, HashSet};
-    use std::sync::Arc;
     use std::vec;
     use std::{error::Error, time::Duration};
-    use tokio_stream::StreamExt;
-    use tower::service_fn;
-    use crate::source::{self, Message, Offset, SourceReadRequest};
+    use crate::source;
     use tempfile::TempDir;
     use tokio::sync::mpsc::Sender;
     use tokio::sync::oneshot;
+    use tokio_stream::StreamExt;
     use tonic::transport::Uri;
+    use tower::service_fn;
     use uuid::Uuid;
 
     // A source that repeats the `num` for the requested count
@@ -388,10 +407,7 @@ mod tests {
         assert_eq!(server.socket_file(), sock_file);
 
         let (shutdown_tx, shutdown_rx) = oneshot::channel();
-        let shutdown = async {
-            shutdown_rx.await.unwrap();
-        };
-        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown).await });
+        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
         tokio::time::sleep(Duration::from_millis(50)).await;
diff --git a/src/sourcetransform.rs b/src/sourcetransform.rs
index 199ed3a..880520c 100644
--- a/src/sourcetransform.rs
+++ b/src/sourcetransform.rs
@@ -1,8 +1,10 @@
 use std::collections::HashMap;
-use std::future::Future;
+use std::fs;
 use std::path::PathBuf;
 
 use chrono::{DateTime, Utc};
+use tokio::sync::{mpsc, oneshot};
+use tokio_util::sync::CancellationToken;
 use tonic::{async_trait, Request, Response, Status};
 
 use crate::shared::{self, prost_timestamp_from_utc};
@@ -12,6 +14,7 @@ const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sourcetransform.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcetransformer-server-info";
 const DROP: &str = "U+005C__DROP__";
+
 /// Numaflow SourceTransformer Proto definitions.
 pub mod proto {
     tonic::include_proto!("sourcetransformer.v1");
@@ -19,6 +22,7 @@ pub mod proto {
 
 struct SourceTransformerService {
     handler: T,
+    _shutdown_tx: mpsc::Sender<()>,
 }
 
 /// SourceTransformer trait for implementing SourceTransform handler.
@@ -304,27 +308,38 @@ impl Server {
     }
 
     /// Starts the gRPC server. When message is received on the `shutdown` channel, graceful shutdown of the gRPC server will be initiated.
-    pub async fn start_with_shutdown(
+    pub async fn start_with_shutdown(
         &mut self,
-        shutdown: F,
+        shutdown_rx: oneshot::Receiver<()>,
     ) -> Result<(), Box>
     where
         T: SourceTransformer + Send + Sync + 'static,
-        F: Future,
     {
         let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?;
         let handler = self.svc.take().unwrap();
-        let sourcetrf_svc = SourceTransformerService { handler };
+        let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);
+
+        let sourcetrf_svc = SourceTransformerService {
+            handler,
+            _shutdown_tx: internal_shutdown_tx,
+        };
         let sourcetrf_svc = proto::source_transform_server::SourceTransformServer::new(sourcetrf_svc)
             .max_encoding_message_size(self.max_message_size)
             .max_decoding_message_size(self.max_message_size);
+        let shutdown = shared::shutdown_signal(
+            internal_shutdown_rx,
+            Some(shutdown_rx),
+            CancellationToken::new(),
+        );
+
         tonic::transport::Server::builder()
             .add_service(sourcetrf_svc)
             .serve_with_incoming_shutdown(listener, shutdown)
-            .await
-            .map_err(Into::into)
+            .await?;
+
+        Ok(())
     }
 
     /// Starts the gRPC server. Automatically registers singal handlers for SIGINT and SIGTERM and initiates graceful shutdown of gRPC server when either one of the singal arrives.
@@ -332,20 +347,30 @@ impl Server {
     where
         T: SourceTransformer + Send + Sync + 'static,
     {
-        self.start_with_shutdown(shared::shutdown_signal()).await
+        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
+        self.start_with_shutdown(shutdown_rx).await
+    }
+}
+
+impl Drop for Server {
+    // Cleanup the socket file when the server is dropped so that when the server is restarted, it can bind to the
+    // same address. UnixListener doesn't implement Drop trait, so we have to manually remove the socket file.
+    fn drop(&mut self) {
+        let _ = fs::remove_file(&self.sock_addr);
     }
 }
 
 #[cfg(test)]
 mod tests {
     use std::{error::Error, time::Duration};
-    use tower::service_fn;
-    use crate::sourcetransform;
-    use crate::sourcetransform::proto::source_transform_client::SourceTransformClient;
     use tempfile::TempDir;
     use tokio::sync::oneshot;
     use tonic::transport::Uri;
+    use tower::service_fn;
+
+    use crate::sourcetransform;
+    use crate::sourcetransform::proto::source_transform_client::SourceTransformClient;
 
     #[tokio::test]
     async fn sourcetransformer_server() -> Result<(), Box> {
@@ -379,10 +404,7 @@ mod tests {
         assert_eq!(server.socket_file(), sock_file);
 
         let (shutdown_tx, shutdown_rx) = oneshot::channel();
-        let shutdown = async {
-            shutdown_rx.await.unwrap();
-        };
-        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown).await });
+        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
         tokio::time::sleep(Duration::from_millis(50)).await;