From f02acf9c23a739cb8cc3d11d4236e8b58bf10583 Mon Sep 17 00:00:00 2001
From: Yashash H L
Date: Mon, 29 Jul 2024 12:29:22 +0530
Subject: [PATCH] chore: handle panics inside user handlers (#67)

Signed-off-by: Yashash H L
---
 src/error.rs           |   6 ++
 src/lib.rs             |  11 +-
 src/map.rs             | 230 +++++++++++++++++++++++++++++++++++----
 src/reduce.rs          | 240 +++++++++++++++++++++++++++++++++--------
 src/shared.rs          |  17 +--
 src/sideinput.rs       |  69 +++++++-----
 src/sink.rs            | 192 +++++++++++++++++++++++++++------
 src/source.rs          |  15 +--
 src/sourcetransform.rs | 138 ++++++++++++++++++++----
 9 files changed, 746 insertions(+), 172 deletions(-)

diff --git a/src/error.rs b/src/error.rs
index c4bdb52..d517799 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -22,4 +22,10 @@ pub enum Error {
 
     #[error("Source Error - {0}")]
     SourceError(ErrorKind),
+
+    #[error("Source Transformer Error: {0}")]
+    SourceTransformerError(ErrorKind),
+
+    #[error("SideInput Error: {0}")]
+    SideInputError(ErrorKind),
 }
diff --git a/src/lib.rs b/src/lib.rs
index 1313682..c7d7320 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,7 +2,7 @@
 //! features. It will support all the core features eventually. It supports [Map], [Reduce], and
 //! [User Defined Sinks].
 //!
-//! Please note that the Rust SDK is experimental and will be refactor in the future to make it more
+//! Please note that the Rust SDK is experimental and will be refactored in the future to make it more
 //! idiomatic.
 //!
 //! [Numaflow]: https://numaflow.numaproj.io/
@@ -35,7 +35,7 @@ pub mod sideinput;
 // Error handling on Numaflow SDKs!
 //
 // Any non-recoverable error will cause the process to shutdown with a non-zero exit status. All errors are non-recoverable.
-// If there are errors that are retriable, we (gRPC or Numaflow SDK) would have already retried it (hence not an error), that means,
+// If there are errors that are retryable, we (gRPC or Numaflow SDK) would have already retried it (hence not an error), that means,
 // all errors raised by the SDK are non-recoverable.
 //
 // Task Ordering and error propagation.
@@ -59,13 +59,14 @@ pub mod sideinput;
 //    |
 //  (user)
 //
-// If a task at level-3 has an error, then that error will be propagated to level-2 (service_fn) via an mpsc::channel using the response channel.
+// If a task at level-3 has an error, then that error will be propagated to level-2 (service_fn) via a mpsc::channel using the response channel.
 // The Response channel passes a Result type and by returning Err() in response channel, it notifies top service_fn that the task wants to abort itself.
 // service_fn (level-2) will now use another mpsc::channel to tell the gRPC service to cancel all the service_fns. gRPC service will
-// will ask all the level-2 service_fns to abort using the CancellationToken. service_fn will call abort on all the tasks it created using internal
+// ask all the level-2 service_fns to abort using the CancellationToken. service_fn will call abort on all the tasks it created using internal
 // mpsc::channel when CancellationToken has been dropped/cancelled.
 //
-// User can directly send shutdown request to the gRPC server which inturn cancels the CancellationToken.
+// User can directly send a shutdown request to the gRPC server, which triggers the shutdown of the server by stopping the acceptance of new requests
+// and draining the existing requests. Lastly, we cancel the cancellation token to make sure all the tasks are aborted.
 //
 // The above 3 level task ordering is only for complex cases like reduce, but for simpler endpoints like `map`, it only has 2 levels but
 // the error propagation is handled the same way.
diff --git a/src/map.rs b/src/map.rs
index e892e9c..56c5345 100644
--- a/src/map.rs
+++ b/src/map.rs
@@ -1,15 +1,16 @@
+use crate::error::Error::MapError;
+use crate::error::ErrorKind::{InternalError, UserDefinedError};
+use crate::shared;
+use crate::shared::shutdown_signal;
+use chrono::{DateTime, Utc};
 use std::collections::HashMap;
 use std::fs;
 use std::path::PathBuf;
-
-use chrono::{DateTime, Utc};
+use std::sync::Arc;
 use tokio::sync::{mpsc, oneshot};
 use tokio_util::sync::CancellationToken;
 use tonic::{async_trait, Request, Response, Status};
 
-use crate::shared;
-use crate::shared::shutdown_signal;
-
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/map.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/mapper-server-info";
@@ -21,10 +22,9 @@ pub mod proto {
 }
 
 struct MapService<T> {
-    handler: T,
-    // not used ATM
-    // PLEASE WRITE WHY
-    _shutdown_tx: mpsc::Sender<()>,
+    handler: Arc<T>,
+    shutdown_tx: mpsc::Sender<()>,
+    cancellation_token: CancellationToken,
 }
 
 /// Mapper trait for implementing Map handler.
@@ -71,11 +71,36 @@ where
         request: Request<proto::MapRequest>,
     ) -> Result<Response<proto::MapResponse>, Status> {
         let request = request.into_inner();
-        let result = self.handler.map(request.into()).await;
+        let handler = Arc::clone(&self.handler);
+        let handle = tokio::spawn(async move { handler.map(request.into()).await });
+        let shutdown_tx = self.shutdown_tx.clone();
+        let cancellation_token = self.cancellation_token.clone();
+
+        // Wait for the handler to finish processing the request. If the server is shutting down (token will be cancelled),
+        // then return an error.
+        tokio::select! {
+            result = handle => {
+                match result {
+                    Ok(result) => Ok(Response::new(proto::MapResponse {
+                        results: result.into_iter().map(|msg| msg.into()).collect(),
+                    })),
+                    Err(e) => {
+                        tracing::error!("Error in map handler: {:?}", e);
+                        // Send a shutdown signal to the server to do a graceful shutdown because there was
+                        // a panic in the handler.
+                        shutdown_tx
+                            .send(())
+                            .await
+                            .expect("Sending shutdown signal to gRPC server");
+                        Err(Status::internal(MapError(UserDefinedError(e.to_string())).to_string()))
+                    }
+                }
+            },
 
-        Ok(Response::new(proto::MapResponse {
-            results: result.into_iter().map(|msg| msg.into()).collect(),
-        }))
+            _ = cancellation_token.cancelled() => {
+                Err(Status::internal(MapError(InternalError("Server is shutting down".to_string())).to_string()))
+            },
+        }
     }
 
     async fn is_ready(&self, _: Request<()>) -> Result<Response<proto::ReadyResponse>, Status> {
@@ -285,23 +310,24 @@ impl<T> Server<T> {
     {
         let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?;
         let handler = self.svc.take().unwrap();
+        let cln_token = CancellationToken::new();
         // Create a channel to send shutdown signal to the server to do graceful shutdown in case of non retryable errors.
         let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);
         let map_svc = MapService {
-            handler,
-            _shutdown_tx: internal_shutdown_tx,
+            handler: Arc::new(handler),
+            shutdown_tx: internal_shutdown_tx,
+            cancellation_token: cln_token.clone(),
         };
 
         let map_svc = proto::map_server::MapServer::new(map_svc)
             .max_encoding_message_size(self.max_message_size)
             .max_decoding_message_size(self.max_message_size);
 
-        let shutdown = shutdown_signal(
-            internal_shutdown_rx,
-            Some(shutdown_rx),
-            CancellationToken::new(),
-        );
+        let shutdown = shutdown_signal(internal_shutdown_rx, Some(shutdown_rx));
+
+        // will call cancel_token.cancel() on drop of _drop_guard
+        let _drop_guard = cln_token.drop_guard();
 
         tonic::transport::Server::builder()
             .add_service(map_svc)
@@ -338,6 +364,7 @@ mod tests {
     use tempfile::TempDir;
     use tokio::net::UnixStream;
     use tokio::sync::oneshot;
+    use tokio::time::sleep;
     use tonic::transport::Uri;
     use tower::service_fn;
 
@@ -409,4 +436,167 @@ mod tests {
         assert!(task.is_finished(), "gRPC server is still running");
         Ok(())
     }
+
+    #[tokio::test]
+    async fn map_server_panic() -> Result<(), Box<dyn std::error::Error>> {
+        struct PanicCat;
+        #[tonic::async_trait]
+        impl map::Mapper for PanicCat {
+            async fn map(&self, _input: map::MapRequest) -> Vec<map::Message> {
+                panic!("PanicCat panicking");
+            }
+        }
+
+        let tmp_dir = TempDir::new()?;
+        let sock_file = tmp_dir.path().join("map.sock");
+        let server_info_file = tmp_dir.path().join("mapper-server-info");
+
+        let mut server = map::Server::new(PanicCat)
+            .with_server_info_file(&server_info_file)
+            .with_socket_file(&sock_file)
+            .with_max_message_size(10240);
+
+        assert_eq!(server.max_message_size(), 10240);
+        assert_eq!(server.server_info_file(), server_info_file);
+        assert_eq!(server.socket_file(), sock_file);
+
+        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
+        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
+
+        tokio::time::sleep(Duration::from_millis(50)).await;
+
+        let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
+            .connect_with_connector(service_fn(move |_: Uri| {
+                let sock_file = sock_file.clone();
+                async move {
+                    Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
+                        UnixStream::connect(sock_file).await?,
+                    ))
+                }
+            }))
+            .await?;
+
+        let mut client = MapClient::new(channel);
+        let request = tonic::Request::new(map::proto::MapRequest {
+            keys: vec!["first".into(), "second".into()],
+            value: "hello".into(),
+            watermark: Some(prost_types::Timestamp::default()),
+            event_time: Some(prost_types::Timestamp::default()),
+            headers: Default::default(),
+        });
+
+        // server should return an error because of the panic.
+        let resp = client.map_fn(request).await;
+        assert!(resp.is_err(), "Expected error from server");
+
+        if let Err(e) = resp {
+            assert_eq!(e.code(), tonic::Code::Internal);
+            assert!(e.message().contains("User Defined Error"));
+        }
+
+        // server should shut down gracefully because there was a panic in the handler.
+        for _ in 0..10 {
+            tokio::time::sleep(Duration::from_millis(10)).await;
+            if task.is_finished() {
+                break;
+            }
+        }
+        assert!(task.is_finished(), "gRPC server is still running");
+        Ok(())
+    }
+
+    // tests for panic when we have multiple inflight requests: only one of the requests
+    // causes a panic; the other requests should be processed successfully and the server
+    // should shut down gracefully.
+    #[tokio::test]
+    async fn panic_with_multiple_requests() -> Result<(), Box<dyn std::error::Error>> {
+        struct PanicCat;
+        #[tonic::async_trait]
+        impl map::Mapper for PanicCat {
+            async fn map(&self, input: map::MapRequest) -> Vec<map::Message> {
+                if !input.keys.is_empty() && input.keys[0] == "key1" {
+                    sleep(Duration::from_millis(20)).await;
+                    panic!("Cat panicked");
+                }
+                // assume each request takes 100ms to process
+                sleep(Duration::from_millis(100)).await;
+                vec![]
+            }
+        }
+
+        let tmp_dir = TempDir::new()?;
+        let sock_file = tmp_dir.path().join("map.sock");
+        let server_info_file = tmp_dir.path().join("mapper-server-info");
+
+        let mut server = map::Server::new(PanicCat)
+            .with_server_info_file(&server_info_file)
+            .with_socket_file(&sock_file)
+            .with_max_message_size(10240);
+
+        assert_eq!(server.max_message_size(), 10240);
+        assert_eq!(server.server_info_file(), server_info_file);
+        assert_eq!(server.socket_file(), sock_file);
+
+        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
+        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
+
+        tokio::time::sleep(Duration::from_millis(50)).await;
+
+        let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
+            .connect_with_connector(service_fn(move |_: Uri| {
+                let sock_file = sock_file.clone();
+                async move {
+                    Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
+                        UnixStream::connect(sock_file).await?,
+                    ))
+                }
+            }))
+            .await?;
+
+        let mut client = MapClient::new(channel);
+
+        let mut client_one = client.clone();
+        tokio::spawn(async move {
+            let request = tonic::Request::new(map::proto::MapRequest {
+                keys: vec!["key2".into()],
+                value: "hello".into(),
+                watermark: Some(prost_types::Timestamp::default()),
+                event_time: Some(prost_types::Timestamp::default()),
+                headers: Default::default(),
+            });
+
+            // panic is only for requests with key "key1"; since we do a graceful shutdown,
+            // the request should get processed.
+            let resp = client_one.map_fn(request).await;
+            assert!(resp.is_ok(), "Expected ok from server");
+        });
+
+        let request = tonic::Request::new(map::proto::MapRequest {
+            keys: vec!["key1".into()],
+            value: "hello".into(),
+            watermark: Some(prost_types::Timestamp::default()),
+            event_time: Some(prost_types::Timestamp::default()),
+            headers: Default::default(),
+        });
+
+        // panic happens for the key1 request, so we should expect an error on the client side.
+        let resp = client.map_fn(request).await;
+        assert!(resp.is_err(), "Expected error from server");
+
+        if let Err(e) = resp {
+            assert_eq!(e.code(), tonic::Code::Internal);
+            assert!(e.message().contains("User Defined Error"));
+        }
+
+        // but since there is a panic, the server should shut down.
+        for _ in 0..10 {
+            tokio::time::sleep(Duration::from_millis(10)).await;
+            if task.is_finished() {
+                break;
+            }
+        }
+
+        assert!(task.is_finished(), "gRPC server is still running");
+        Ok(())
+    }
 }
diff --git a/src/reduce.rs b/src/reduce.rs
index 2dc7760..554c8bd 100644
--- a/src/reduce.rs
+++ b/src/reduce.rs
@@ -328,6 +328,39 @@ where
         // commands to the task executor and an oneshot tx to abort all the tasks.
         let (task_tx, abort_tx) = TaskSet::start_task_executor(creator, response_tx.clone());
 
+        // Spawn a new task to handle the incoming ReduceRequests from the client
+        let reader_handle = tokio::spawn(async move {
+            let mut stream = request.into_inner();
+            loop {
+                match stream.next().await {
+                    Some(Ok(rr)) => {
+                        task_tx
+                            .send(TaskCommand::HandleReduceRequest(rr))
+                            .await
+                            .expect("task_tx send failed");
+                    }
+                    Some(Err(e)) => {
+                        response_tx
+                            .send(Err(ReduceError(InternalError(format!(
+                                "Failed to receive request: {}",
+                                e
+                            )))))
+                            .await
+                            .expect("error_tx send failed");
+                        break;
+                    }
+                    // COB
+                    None => {
+                        task_tx
+                            .send(TaskCommand::Close)
+                            .await
+                            .expect("task_tx send failed");
+                        break;
+                    }
+                }
+            }
+        });
+
         // Spawn a new task to listen to the response channel and send the response back to the grpc client.
         // In case of error, it propagates the error back to the client in grpc status format and sends a shutdown
         // signal to the grpc server. It also listens to the cancellation signal and aborts all the tasks.
@@ -349,10 +382,13 @@ where
                         }
                     }
                     Some(Err(error)) => {
+                        tracing::error!("Error from task: {:?}", error);
                         grpc_response_tx
                             .send(Err(Status::internal(error.to_string())))
                             .await
                             .expect("send to grpc response channel failed");
+                        // stop reading new messages from the stream.
+                        reader_handle.abort();
                         // Send a shutdown signal to the grpc server.
                         shutdown_tx.send(()).await.expect("shutdown_tx send failed");
                     }
@@ -363,6 +399,8 @@ where
                     }
                 }
                 _ = response_task_token.cancelled() => {
+                    // stop reading new messages from stream.
+                    reader_handle.abort();
                     // Send an abort signal to the task executor to abort all the tasks.
                     abort_tx.send(()).expect("task_tx send failed");
                     break;
@@ -371,46 +409,6 @@ where
             }
         });
 
-        let request_cancel_token = self.cancellation_token.clone();
-
-        // Spawn a new task to handle the incoming ReduceRequests from the client
-        tokio::spawn(async move {
-            let mut stream = request.into_inner();
-            loop {
-                tokio::select! {
-                    reduce_request = stream.next() => {
-                        match reduce_request {
-                            Some(Ok(rr)) => {
-                                task_tx
-                                    .send(TaskCommand::HandleReduceRequest(rr))
-                                    .await
-                                    .expect("task_tx send failed");
-                            }
-                            Some(Err(e)) => {
-                                response_tx
-                                    .send(Err(ReduceError(InternalError(format!(
-                                        "Failed to receive request: {}",
-                                        e
-                                    )))))
-                                    .await
-                                    .expect("error_tx send failed");
-                            }
-                            // COB
-                            None => break,
-                        }
-                    }
-                    _ = request_cancel_token.cancelled() => {
-                        // stop reading because server is shutting down
-                        break;
-                    }
-                }
-            }
-            task_tx
-                .send(TaskCommand::Close)
-                .await
-                .expect("task_tx send failed");
-        });
-
         // return the rx as the streaming endpoint
         Ok(Response::new(ReceiverStream::new(grpc_response_rx)))
     }
@@ -693,8 +691,8 @@ where
         // Extract the start and end time from the window
         let window = &windows[0];
         let (start_time, end_time) = (
-            shared::utc_from_timestamp(window.start.clone()),
-            shared::utc_from_timestamp(window.end.clone()),
+            shared::utc_from_timestamp(window.start),
+            shared::utc_from_timestamp(window.end),
        );
 
         // Create the IntervalWindow
@@ -832,8 +830,10 @@ impl<C> Server<C> {
             .max_encoding_message_size(self.max_message_size)
             .max_decoding_message_size(self.max_message_size);
 
-        let shutdown =
-            shared::shutdown_signal(internal_shutdown_rx, Some(user_shutdown_rx), cln_token);
+        let shutdown = shared::shutdown_signal(internal_shutdown_rx, Some(user_shutdown_rx));
+
+        // will call cancel_token.cancel() on drop of _drop_guard
+        let _drop_guard = cln_token.drop_guard();
 
         tonic::transport::Server::builder()
             .add_service(reduce_svc)
@@ -1269,4 +1269,156 @@ mod tests {
 
         Ok(())
     }
+
+    // test panic in reduce method when there are multiple inflight requests:
+    // panic only happens for one of the requests; the other request should be
+    // processed successfully since we do a graceful shutdown of the server.
+    #[tokio::test]
+    async fn panic_with_multiple_keys() -> Result<(), Box<dyn std::error::Error>> {
+        struct PanicReducerCreator;
+
+        impl reduce::ReducerCreator for PanicReducerCreator {
+            type R = PanicReducer;
+            fn create(&self) -> PanicReducer {
+                PanicReducer {}
+            }
+        }
+
+        struct PanicReducer;
+
+        #[tonic::async_trait]
+        impl reduce::Reducer for PanicReducer {
+            async fn reduce(
+                &self,
+                keys: Vec<String>,
+                mut input: mpsc::Receiver<reduce::ReduceRequest>,
+                _md: &reduce::Metadata,
+            ) -> Vec<reduce::Message> {
+                let mut count = 0;
+                while input.recv().await.is_some() {
+                    count += 1;
+                    if count == 10 && keys[0] == "key2" {
+                        panic!("Panic in reduce method");
+                    }
+                }
+                vec![]
+            }
+        }
+        let (mut server, sock_file, _) = setup_server(PanicReducerCreator).await?;
+
+        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
+
+        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
+
+        tokio::time::sleep(Duration::from_millis(50)).await;
+
+        let client = setup_client(sock_file.clone()).await?;
+
+        let (tx1, rx1) = mpsc::channel(1);
+
+        let (tx2, rx2) = mpsc::channel(1);
+
+        // Spawn a task to send ReduceRequests to the channel
+        tokio::spawn(async move {
+            let rr = reduce::proto::ReduceRequest {
+                payload: Some(reduce::proto::reduce_request::Payload {
+                    keys: vec!["key1".to_string()],
+                    value: vec![],
+                    watermark: None,
+                    event_time: None,
+                    headers: Default::default(),
+                }),
+                operation: Some(reduce::proto::reduce_request::WindowOperation {
+                    event: 0,
+                    windows: vec![reduce::proto::Window {
+                        start: Some(Timestamp {
+                            seconds: 60000,
+                            nanos: 0,
+                        }),
+                        end: Some(Timestamp {
+                            seconds: 120000,
+                            nanos: 0,
+                        }),
+                        slot: "slot-0".to_string(),
+                    }],
+                }),
+            };
+
+            for _ in 0..20 {
+                tx1.send(rr.clone()).await.unwrap();
+                sleep(Duration::from_millis(10)).await;
+            }
+        });
+
+        tokio::spawn(async move {
+            let rr = reduce::proto::ReduceRequest {
+                payload: Some(reduce::proto::reduce_request::Payload {
+                    keys: vec!["key2".to_string()],
+                    value: vec![],
+                    watermark: None,
+                    event_time: None,
+                    headers: Default::default(),
+                }),
+                operation: Some(reduce::proto::reduce_request::WindowOperation {
+                    event: 0,
+                    windows: vec![reduce::proto::Window {
+                        start: Some(Timestamp {
+                            seconds: 60000,
+                            nanos: 0,
+                        }),
+                        end: Some(Timestamp {
+                            seconds: 120000,
+                            nanos: 0,
+                        }),
+                        slot: "slot-0".to_string(),
+                    }],
+                }),
+            };
+
+            for _ in 0..10 {
+                tx2.send(rr.clone()).await.unwrap();
+                sleep(Duration::from_millis(10)).await;
+            }
+        });
+
+        // Convert the receiver end of the channel into a stream
+        let stream1 = ReceiverStream::new(rx1);
+
+        let stream2 = ReceiverStream::new(rx2);
+
+        // Create a tonic::Request from the stream
+        let request1 = Request::new(stream1);
+
+        let request2 = Request::new(stream2);
+
+        let mut first_client = client.clone();
+        tokio::spawn(async move {
+            let mut response_stream = first_client.reduce_fn(request1).await.unwrap().into_inner();
+            assert!(response_stream.message().await.is_ok());
+        });
+
+        let mut second_client = client.clone();
+        tokio::spawn(async move {
+            let mut response_stream = second_client
+                .reduce_fn(request2)
+                .await
+                .unwrap()
+                .into_inner();
+
+            if let Err(e) = response_stream.message().await {
+                assert_eq!(e.code(), tonic::Code::Internal);
+                assert!(e.message().contains("User Defined Error"))
+            }
+        });
+
+        for _ in 0..10 {
+            tokio::time::sleep(Duration::from_millis(100)).await;
+            if task.is_finished() {
+                break;
+            }
+        }
+        assert!(task.is_finished(), "gRPC server is still running");
+
+        Ok(())
+    }
 }
diff --git a/src/shared.rs b/src/shared.rs
index 9d4a88a..6a4123c 100644
--- a/src/shared.rs
+++ b/src/shared.rs
@@ -8,14 +8,13 @@ use tokio::net::UnixListener;
 use tokio::signal;
 use tokio::sync::{mpsc, oneshot};
 use tokio_stream::wrappers::UnixListenerStream;
-use tokio_util::sync::CancellationToken;
 use tracing::info;
 
 // #[tracing::instrument(skip(path), fields(path = ?path.as_ref()))]
 #[tracing::instrument(fields(path = ? path.as_ref()))]
 fn write_info_file(path: impl AsRef<Path>) -> io::Result<()> {
     let parent = path.as_ref().parent().unwrap();
-    std::fs::create_dir_all(parent)?;
+    fs::create_dir_all(parent)?;
 
     // TODO: make port-number and CPU meta-data configurable, e.g., ("CPU_LIMIT", "1")
     let metadata: HashMap<String, String> = HashMap::new();
@@ -57,18 +56,13 @@ pub(crate) fn prost_timestamp_from_utc(t: DateTime<Utc>) -> Option<Timestamp> {
 
 /// shuts downs the gRPC server. This happens in 2 cases
 /// 1. there has been an internal error (one of the tasks failed) and we need to shutdown
-/// 2. user is explictly asking us to shutdown
+/// 2. user is explicitly asking us to shut down
 /// Once the request for shutdown has be invoked, server will broadcast shutdown to all tasks
 /// through the cancellation-token.
 pub(crate) async fn shutdown_signal(
     mut shutdown_on_err: mpsc::Receiver<()>,
     shutdown_from_user: Option<oneshot::Receiver<()>>,
-    cancel_token: CancellationToken,
 ) {
-    // will call cancel_token.cancel() when the function exits
-    // because of abort request, ctrl-c, or SIGTERM signal
-    let _drop_guard = cancel_token.drop_guard();
-
     let ctrl_c = async {
         signal::ctrl_c()
             .await
@@ -180,12 +174,7 @@ mod tests {
 
         // Spawn a new task to call shutdown_signal
         let shutdown_signal_task = tokio::spawn(async move {
-            shutdown_signal(
-                internal_shutdown_rx,
-                Some(user_shutdown_rx),
-                CancellationToken::new(),
-            )
-            .await;
+            shutdown_signal(internal_shutdown_rx, Some(user_shutdown_rx)).await;
         });
 
         // Send a shutdown signal
diff --git a/src/sideinput.rs b/src/sideinput.rs
index b5d3bb6..6256a46 100644
--- a/src/sideinput.rs
+++ b/src/sideinput.rs
@@ -1,13 +1,14 @@
+use crate::error::Error::SideInputError;
+use crate::error::ErrorKind::{InternalError, UserDefinedError};
+use crate::shared;
+use crate::shared::shutdown_signal;
 use std::fs;
 use std::path::PathBuf;
-
+use std::sync::Arc;
 use tokio::sync::{mpsc, oneshot};
 use tokio_util::sync::CancellationToken;
 use tonic::{async_trait, Request, Response, Status};
 
-use crate::shared;
-use crate::shared::shutdown_signal;
-
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sideinput.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sideinput-server-info";
@@ -17,8 +18,9 @@ mod proto {
 }
 
 struct SideInputService<T> {
-    handler: T,
-    _shutdown_tx: mpsc::Sender<()>,
+    handler: Arc<T>,
+    shutdown_tx: mpsc::Sender<()>,
+    cancellation_token: CancellationToken,
 }
 
 /// The `SideInputer` trait defines a method for retrieving side input data.
@@ -94,19 +96,35 @@ where
         &self,
         _: Request<()>,
     ) -> Result<Response<proto::SideInputResponse>, Status> {
-        let msg = self.handler.retrieve_sideinput().await;
-        let si = match msg {
-            Some(value) => proto::SideInputResponse {
-                value,
-                no_broadcast: false,
+        let handler = Arc::clone(&self.handler);
+        let shutdown_tx = self.shutdown_tx.clone();
+        let handle = tokio::spawn(async move { handler.retrieve_sideinput().await });
+
+        tokio::select! {
+            msg = handle => {
+                match msg {
+                    Ok(Some(value)) => {
+                        Ok(Response::new(proto::SideInputResponse {
+                            value,
+                            no_broadcast: false,
+                        }))
+                    }
+                    Ok(None) => {
+                        Ok(Response::new(proto::SideInputResponse {
+                            value: Vec::new(),
+                            no_broadcast: true,
+                        }))
+                    }
+                    Err(e) => {
+                        shutdown_tx.send(()).await.expect("Failed to send shutdown signal");
+                        Err(Status::internal(SideInputError(UserDefinedError(e.to_string())).to_string()))
+                    }
+                }
+            }
+            _ = self.cancellation_token.cancelled() => {
+                Err(Status::internal(SideInputError(InternalError("Server is shutting down".to_string())).to_string()))
             },
-            None => proto::SideInputResponse {
-                value: Vec::new(),
-                no_broadcast: true,
-            },
-        };
-
-        Ok(Response::new(si))
+        }
     }
 
     async fn is_ready(&self, _: Request<()>) -> Result<Response<proto::ReadyResponse>, Status> {
@@ -179,20 +197,21 @@ impl<T> Server<T> {
         let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?;
         let handler = self.svc.take().unwrap();
         let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);
+        let cln_token = CancellationToken::new();
         let sideinput_svc = SideInputService {
-            handler,
-            _shutdown_tx: internal_shutdown_tx,
+            handler: Arc::new(handler),
+            shutdown_tx: internal_shutdown_tx,
+            cancellation_token: cln_token.clone(),
         };
         let sideinput_svc = proto::side_input_server::SideInputServer::new(sideinput_svc)
             .max_encoding_message_size(self.max_message_size)
             .max_decoding_message_size(self.max_message_size);
 
-        let shutdown = shutdown_signal(
-            internal_shutdown_rx,
-            Some(shutdown_rx),
-            CancellationToken::new(),
-        );
+        let shutdown = shutdown_signal(internal_shutdown_rx, Some(shutdown_rx));
+
+        // will call cancel_token.cancel() on drop of _drop_guard
+        let _drop_guard = cln_token.drop_guard();
 
         tonic::transport::Server::builder()
             .add_service(sideinput_svc)
diff --git a/src/sink.rs b/src/sink.rs
index aed4b53..618bbee 100644
--- a/src/sink.rs
+++ b/src/sink.rs
@@ -1,14 +1,15 @@
+use crate::error::Error::SinkError;
+use crate::error::ErrorKind::{InternalError, UserDefinedError};
+use crate::shared;
+use chrono::{DateTime, Utc};
 use std::collections::HashMap;
 use std::path::PathBuf;
+use std::sync::Arc;
 use std::{env, fs};
-
-use chrono::{DateTime, Utc};
 use tokio::sync::{mpsc, oneshot};
 use tokio_util::sync::CancellationToken;
 use tonic::{Request, Status, Streaming};
 
-use crate::shared;
-
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sink.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sinker-server-info";
@@ -24,8 +25,9 @@ pub mod proto {
 }
 
 struct SinkService<T> {
-    handler: T,
-    _shutdown_tx: mpsc::Sender<()>,
+    handler: Arc<T>,
+    shutdown_tx: mpsc::Sender<()>,
+    cancellation_token: CancellationToken,
 }
 
 /// Sinker trait for implementing user defined sinks.
@@ -180,33 +182,66 @@ where
         request: Request<Streaming<proto::SinkRequest>>,
     ) -> Result<tonic::Response<proto::SinkResponse>, Status> {
         let mut stream = request.into_inner();
-
+        let sink_handle = self.handler.clone();
+        let cancellation_token = self.cancellation_token.clone();
+        let shutdown_tx = self.shutdown_tx.clone();
         // TODO: what should be the idle buffer size?
         let (tx, rx) = mpsc::channel::<SinkRequest>(1);
 
-        // call the user's sink handle
-        let sink_handle = self.handler.sink(rx);
-
-        // write to the user-defined channel
-        tokio::spawn(async move {
-            while let Some(next_message) = stream
-                .message()
-                .await
-                .expect("expected next message from stream")
-            {
-                // FIXME: panic is very bad idea!
-                tx.send(next_message.into())
-                    .await
-                    .expect("send be successfully received!");
+        let reader_shutdown_tx = shutdown_tx.clone();
+        // spawn a task to read messages from the stream and send them to the user's sink handle
+        let reader_handle = tokio::spawn(async move {
+            loop {
+                match stream.message().await {
+                    Ok(Some(message)) => {
+                        // If sending fails, it means the receiver is dropped, and we should stop the task.
+                        if let Err(e) = tx.send(message.into()).await {
+                            tracing::error!("Failed to send message: {}", e);
+                            break;
+                        }
+                    }
+                    // If there's an error or the stream ends, break the loop to stop the task.
+                    Ok(None) => break,
+                    Err(e) => {
+                        tracing::error!("Error reading message from stream: {}", e);
+                        reader_shutdown_tx
+                            .send(())
+                            .await
+                            .expect("Sending shutdown signal to gRPC server");
+                        break;
+                    }
+                }
             }
         });
 
-        // wait for the sink handle to respond
-        let responses = sink_handle.await;
+        // call the user's sink handle
+        let handle = tokio::spawn(async move { sink_handle.sink(rx).await });
+
+        // Wait for the handler to finish processing the request. If the server is shutting down (token will be cancelled),
+        // then return an error.
+        tokio::select! {
+            result = handle => {
+                match result {
+                    Ok(responses) => {
+                        Ok(tonic::Response::new(proto::SinkResponse {
+                            results: responses.into_iter().map(|r| r.into()).collect(),
+                        }))
+                    }
+                    Err(e) => {
+                        // Send a shutdown signal to the server to do a graceful shutdown because there was
+                        // a panic in the handler.
+                        shutdown_tx.send(()).await.expect("Sending shutdown signal to gRPC server");
+                        Err(Status::internal(SinkError(UserDefinedError(e.to_string())).to_string()))
+                    }
+                }
+            },
 
-        Ok(tonic::Response::new(proto::SinkResponse {
-            results: responses.into_iter().map(|r| r.into()).collect(),
-        }))
+            _ = cancellation_token.cancelled() => {
+                // abort the reader task to stop reading messages from the stream
+                reader_handle.abort();
+                Err(Status::cancelled(SinkError(InternalError("Server is shutting down".to_string())).to_string()))
+            }
+        }
     }
 
     async fn is_ready(
         &self,
@@ -269,7 +304,7 @@ impl<T> Server<T> {
         self.max_message_size
     }
 
-    /// Change the file in which numflow server information is stored on start up to the new value. Default value is `/var/run/numaflow/sinker-server-info`
+    /// Change the file in which numaflow server information is stored on start up to the new value. Default value is `/var/run/numaflow/sinker-server-info`
     pub fn with_server_info_file(mut self, file: impl Into<PathBuf>) -> Self {
         self.server_info_file = file.into();
         self
@@ -290,22 +325,23 @@ impl<T> Server<T> {
     {
         let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?;
         let handler = self.svc.take().unwrap();
+        let cln_token = CancellationToken::new();
         let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);
         let svc = SinkService {
-            handler,
-            _shutdown_tx: internal_shutdown_tx,
+            handler: Arc::new(handler),
+            shutdown_tx: internal_shutdown_tx,
+            cancellation_token: cln_token.clone(),
         };
 
         let svc = proto::sink_server::SinkServer::new(svc)
             .max_encoding_message_size(self.max_message_size)
             .max_decoding_message_size(self.max_message_size);
 
-        let shutdown = shared::shutdown_signal(
-            internal_shutdown_rx,
-            Some(shutdown_rx),
-            CancellationToken::new(),
-        );
+        let shutdown = shared::shutdown_signal(internal_shutdown_rx, Some(shutdown_rx));
+
+        // will call cancel_token.cancel() on drop of _drop_guard
+        let _drop_guard = cln_token.drop_guard();
 
         tonic::transport::Server::builder()
             .add_service(svc)
@@ -436,4 +472,92 @@ mod tests {
         assert!(task.is_finished(), "gRPC server is still running");
         Ok(())
     }
+
+    #[tokio::test]
+    async fn sink_panic() -> Result<(), Box<dyn std::error::Error>> {
+        struct PanicSink;
+        #[tonic::async_trait]
+        impl sink::Sinker for PanicSink {
+            async fn sink(
+                &self,
+                mut input: tokio::sync::mpsc::Receiver<sink::SinkRequest>,
+            ) -> Vec<sink::Response> {
+                let mut responses: Vec<sink::Response> = Vec::new();
+                let mut count = 0;
+
+                while let Some(datum) = input.recv().await {
+                    if count > 5 {
+                        panic!("Should not cross 5");
+                    }
+                    count += 1;
+                    responses.push(sink::Response::ok(datum.id));
+                }
+                responses
+            }
+        }
+
+        let tmp_dir = TempDir::new()?;
+        let sock_file = tmp_dir.path().join("sink.sock");
+        let server_info_file = tmp_dir.path().join("sinker-server-info");
+
+        let mut server = sink::Server::new(PanicSink)
+            .with_server_info_file(&server_info_file)
+            .with_socket_file(&sock_file)
+            .with_max_message_size(10240);
+
+        assert_eq!(server.max_message_size(), 10240);
+        assert_eq!(server.server_info_file(), server_info_file);
+        assert_eq!(server.socket_file(), sock_file);
+
+        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
+        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
+
+        tokio::time::sleep(Duration::from_millis(50)).await;
+
+        // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs
+        let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
+            .connect_with_connector(service_fn(move |_: Uri| {
+                // https://rust-lang.github.io/async-book/03_async_await/01_chapter.html#async-lifetimes
+                let sock_file = sock_file.clone();
+                async move {
+                    Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
+                        UnixStream::connect(sock_file).await?,
+                    ))
+                }
+            }))
+            .await?;
+
+        let mut client = SinkClient::new(channel);
+        let mut requests = Vec::new();
+
+        for i in 0..10 {
+            let request = sink::proto::SinkRequest {
+                keys: vec!["first".into(), "second".into()],
+                value: format!("hello {}", i).into(),
+                watermark: Some(prost_types::Timestamp::default()),
+                event_time: Some(prost_types::Timestamp::default()),
+                id: i.to_string(),
+                headers: Default::default(),
+            };
+            requests.push(request);
+        }
+
+        let resp = client.sink_fn(tokio_stream::iter(requests)).await;
+        assert!(resp.is_err(), "Expected error from server");
+
+        if let Err(e) = resp {
+            assert_eq!(e.code(), tonic::Code::Internal);
+            assert!(e.message().contains("User Defined Error"));
+        }
+
+        // server should shut down gracefully because there was a panic in the handler.
+        for _ in 0..10 {
+            tokio::time::sleep(Duration::from_millis(10)).await;
+            if task.is_finished() {
+                break;
+            }
+        }
+        assert!(task.is_finished(), "gRPC server is still running");
+        Ok(())
+    }
 }
diff --git a/src/source.rs b/src/source.rs
index 2d8b29e..abd8d41 100644
--- a/src/source.rs
+++ b/src/source.rs
@@ -4,6 +4,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 use std::time::Duration;
 
+use crate::shared::{self, prost_timestamp_from_utc};
 use chrono::{DateTime, Utc};
 use tokio::sync::mpsc::{self, Sender};
 use tokio::sync::oneshot;
@@ -11,8 +12,6 @@ use tokio_stream::wrappers::ReceiverStream;
 use tokio_util::sync::CancellationToken;
 use tonic::{async_trait, Request, Response, Status};
 
-use crate::shared::{self, prost_timestamp_from_utc};
-
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/source.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcer-server-info";
@@ -25,6 +24,7 @@ pub mod proto {
 struct SourceService<T> {
     handler: Arc<T>,
     _shutdown_tx: Sender<()>,
+    _cancellation_token: CancellationToken,
 }
 
 #[async_trait]
@@ -267,21 +267,22 @@ impl<T> Server<T> {
         let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?;
         let handler = self.svc.take().unwrap();
         let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);
+        let cln_token = CancellationToken::new();
         let source_service = SourceService {
             handler: Arc::new(handler),
             _shutdown_tx: internal_shutdown_tx,
+            _cancellation_token: cln_token.clone(),
         };
 
         let source_svc = proto::source_server::SourceServer::new(source_service)
             .max_decoding_message_size(self.max_message_size)
             .max_decoding_message_size(self.max_message_size);
 
-        let shutdown = shared::shutdown_signal(
-            internal_shutdown_rx,
-            Some(shutdown_rx),
-            CancellationToken::new(),
-        );
+        let shutdown = shared::shutdown_signal(internal_shutdown_rx, Some(shutdown_rx));
+
+        // will call cancel_token.cancel() on drop of _drop_guard
+        let _drop_guard = cln_token.drop_guard();
 
         tonic::transport::Server::builder()
             .add_service(source_svc)
diff --git a/src/sourcetransform.rs b/src/sourcetransform.rs
index 5e06eea..b95fe72 100644
--- a/src/sourcetransform.rs
+++ b/src/sourcetransform.rs
@@ -1,14 +1,15 @@
+use crate::error::Error::SourceTransformerError;
+use crate::error::ErrorKind::UserDefinedError;
+use crate::shared::{self, prost_timestamp_from_utc};
+use chrono::{DateTime, Utc};
 use std::collections::HashMap;
 use std::fs;
 use std::path::PathBuf;
-
-use chrono::{DateTime, Utc};
+use std::sync::Arc;
 use tokio::sync::{mpsc, oneshot};
 use tokio_util::sync::CancellationToken;
 use tonic::{async_trait, Request, Response, Status};
 
-use crate::shared::{self, prost_timestamp_from_utc};
-
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sourcetransform.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcetransformer-server-info";
@@ -21,8 +22,9 @@ pub mod proto {
 }
 
 struct SourceTransformerService<T> {
-    handler: T,
-    _shutdown_tx: mpsc::Sender<()>,
+    handler: Arc<T>,
+    shutdown_tx: mpsc::Sender<()>,
+    cancellation_token: CancellationToken,
 }
 
 /// SourceTransformer trait for implementing SourceTransform handler.
@@ -238,15 +240,32 @@ where
         request: Request<proto::SourceTransformRequest>,
     ) -> Result<Response<proto::SourceTransformResponse>, Status> {
         let request = request.into_inner();
-
-        let messages = self.handler.transform(request.into()).await;
-
-        Ok(Response::new(proto::SourceTransformResponse {
-            results: messages
-                .into_iter()
-                .map(|msg| msg.into())
-                .collect::<Vec<_>>(),
-        }))
+        let handler = Arc::clone(&self.handler);
+        let handle = tokio::spawn(async move { handler.transform(request.into()).await });
+        let shutdown_tx = self.shutdown_tx.clone();
+        let cancellation_token = self.cancellation_token.clone();
+
+        // Wait for the handler to finish processing the request. If the server is shutting down (token will be cancelled),
+        // then return an error.
+        tokio::select! {
+            result = handle => {
+                match result {
+                    Ok(messages) => Ok(Response::new(proto::SourceTransformResponse {
+                        results: messages.into_iter().map(|msg| msg.into()).collect(),
+                    })),
+                    Err(e) => {
+                        tracing::error!("Error in source transform handler: {:?}", e);
+                        // Send a shutdown signal to the server to do a graceful shutdown because there was
+                        // a panic in the handler.
+                        shutdown_tx.send(()).await.expect("Sending shutdown signal to gRPC server");
+                        Err(Status::internal(SourceTransformerError(UserDefinedError(e.to_string())).to_string()))
+                    }
+                }
+            },
+            _ = cancellation_token.cancelled() => {
+                Err(Status::internal(SourceTransformerError(UserDefinedError("Server is shutting down".to_string())).to_string()))
+            },
+        }
     }
 
     async fn is_ready(&self, _: Request<()>) -> Result<Response<proto::ReadyResponse>, Status> {
@@ -318,21 +337,22 @@ impl<T> Server<T> {
         let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?;
         let handler = self.svc.take().unwrap();
         let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);
+        let cln_token = CancellationToken::new();
         let sourcetrf_svc = SourceTransformerService {
-            handler,
-            _shutdown_tx: internal_shutdown_tx,
+            handler: Arc::new(handler),
+            shutdown_tx: internal_shutdown_tx,
+            cancellation_token: cln_token.clone(),
         };
 
         let sourcetrf_svc =
             proto::source_transform_server::SourceTransformServer::new(sourcetrf_svc)
                 .max_encoding_message_size(self.max_message_size)
                 .max_decoding_message_size(self.max_message_size);
 
-        let shutdown = shared::shutdown_signal(
-            internal_shutdown_rx,
-            Some(shutdown_rx),
-            CancellationToken::new(),
-        );
+        let shutdown = shared::shutdown_signal(internal_shutdown_rx, Some(shutdown_rx));
+
+        // will call cancel_token.cancel() on drop of _drop_guard
+        let _drop_guard = cln_token.drop_guard();
 
         tonic::transport::Server::builder()
             .add_service(sourcetrf_svc)
@@ -374,7 +394,7 @@ mod tests {
     use crate::sourcetransform::proto::source_transform_client::SourceTransformClient;
 
     #[tokio::test]
-    async fn sourcetransformer_server() -> Result<(), Box<dyn std::error::Error>> {
+    async fn source_transformer_server() -> Result<(), Box<dyn std::error::Error>> {
         struct NowCat;
         #[tonic::async_trait]
         impl sourcetransform::SourceTransformer for NowCat {
@@ -445,4 +465,76 @@ mod tests {
         assert!(task.is_finished(), "gRPC server is still running");
         Ok(())
     }
+
+    #[tokio::test]
+    async fn source_transformer_panic() -> Result<(), Box<dyn std::error::Error>> {
+        struct PanicTransformer;
+        #[tonic::async_trait]
+        impl sourcetransform::SourceTransformer for PanicTransformer {
+            async fn transform(
+                &self,
+                _: sourcetransform::SourceTransformRequest,
+            ) -> Vec<sourcetransform::Message> {
+                panic!("Panic in transformer");
+            }
+        }
+
+        let tmp_dir = TempDir::new()?;
+        let sock_file = tmp_dir.path().join("sourcetransform.sock");
+        let server_info_file = tmp_dir.path().join("sourcetransformer-server-info");
+
+        let mut server = sourcetransform::Server::new(PanicTransformer)
+            .with_server_info_file(&server_info_file)
+            .with_socket_file(&sock_file)
+            .with_max_message_size(10240);
+
+        assert_eq!(server.max_message_size(), 10240);
+        assert_eq!(server.server_info_file(), server_info_file);
+        assert_eq!(server.socket_file(), sock_file);
+
+        let (_shutdown_tx, shutdown_rx) = oneshot::channel();
+        let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });
+
+        tokio::time::sleep(Duration::from_millis(50)).await;
+
+        // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs
+        let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
+ .connect_with_connector(service_fn(move |_: Uri| { + // https://rust-lang.github.io/async-book/03_async_await/01_chapter.html#async-lifetimes + let sock_file = sock_file.clone(); + async move { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + UnixStream::connect(sock_file).await?, + )) + } + })) + .await?; + + let mut client = SourceTransformClient::new(channel); + let request = tonic::Request::new(sourcetransform::proto::SourceTransformRequest { + keys: vec!["first".into(), "second".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + headers: Default::default(), + }); + + let resp = client.source_transform_fn(request).await; + assert!(resp.is_err(), "Expected error from server"); + + if let Err(e) = resp { + assert_eq!(e.code(), tonic::Code::Internal); + assert!(e.message().contains("User Defined Error")); + } + + // server should shut down gracefully because there was a panic in the handler. + for _ in 0..10 { + tokio::time::sleep(Duration::from_millis(10)).await; + if task.is_finished() { + break; + } + } + assert!(task.is_finished(), "gRPC server is still running"); + Ok(()) + } }
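
Taken together, the patch applies one pattern to every unary handler (map, side input, source transform): run the user's code on its own tokio task so a panic surfaces as a `JoinError`, race that task against a cancellation token, and turn a panic into a shutdown signal plus an Internal status for the in-flight request. Below is a minimal, self-contained sketch of that pattern outside the SDK; it assumes only `tokio` (rt, macros, sync, time features) and `tokio-util` as dependencies, and the names `fallible_handler` and `serve_one` are illustrative, not SDK API:

```rust
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

// Stand-in for a user-defined handler; panics on a poison input.
async fn fallible_handler(input: &str) -> String {
    if input == "boom" {
        panic!("panic inside the user handler");
    }
    format!("mapped: {input}")
}

// Stand-in for one unary gRPC call in the patched services.
async fn serve_one(
    input: &str,
    shutdown_tx: mpsc::Sender<()>,
    token: CancellationToken,
) -> Result<String, String> {
    // Run the user code on its own task: a panic is caught by the runtime
    // and reported as a JoinError instead of unwinding through the service.
    let input = input.to_owned();
    let handle = tokio::spawn(async move { fallible_handler(&input).await });

    tokio::select! {
        res = handle => match res {
            Ok(out) => Ok(out),
            // The handler panicked: queue a graceful shutdown of the whole
            // server, and fail this request with a "User Defined Error".
            Err(e) => {
                let _ = shutdown_tx.send(()).await;
                Err(format!("User Defined Error: {e}"))
            }
        },
        // Another request already triggered shutdown; fail fast.
        _ = token.cancelled() => Err("Server is shutting down".to_owned()),
    }
}

#[tokio::main]
async fn main() {
    let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
    let token = CancellationToken::new();
    // Cancels the token when dropped, like the `_drop_guard` in each
    // `start_with_shutdown` above.
    let _drop_guard = token.clone().drop_guard();

    let ok = serve_one("hello", shutdown_tx.clone(), token.clone()).await;
    println!("{ok:?}"); // Ok("mapped: hello")

    // The default panic hook still prints the panic message to stderr.
    let err = serve_one("boom", shutdown_tx.clone(), token.clone()).await;
    println!("{err:?}"); // Err("User Defined Error: ...")

    // The failed request queued a shutdown signal for the server loop.
    let got = tokio::time::timeout(Duration::from_millis(100), shutdown_rx.recv())
        .await
        .expect("timed out waiting for the shutdown signal");
    assert_eq!(got, Some(()));
}
```

The `DropGuard` detail is what makes the cancellation arm fire for requests that are still in flight when the server exits: `shutdown_signal` no longer owns the token, so whichever way `start_with_shutdown` returns, dropping the guard cancels the token.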
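
The streaming endpoints (sink, and reduce with its task executor) add one more moving part: a reader task that forwards requests from the gRPC stream into the channel consumed by the user task, and that is aborted once a panic or cancellation is observed so the server stops pulling new work. A sketch under the same assumptions, with an mpsc receiver standing in for `tonic::Streaming`:

```rust
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    // Stands in for the inbound gRPC stream of requests.
    let (stream_tx, mut stream) = mpsc::channel::<String>(8);
    // Channel feeding the user-defined sink, as in `sink_fn`.
    let (tx, mut rx) = mpsc::channel::<String>(1);
    let token = CancellationToken::new();

    // Reader task: forwards until the stream ends or the receiver is gone.
    let reader = tokio::spawn(async move {
        while let Some(msg) = stream.recv().await {
            if tx.send(msg).await.is_err() {
                break; // user task died (e.g. panicked); stop forwarding
            }
        }
    });

    // User task: a sink that panics on a poison message.
    let user = tokio::spawn(async move {
        let mut seen = Vec::new();
        while let Some(msg) = rx.recv().await {
            if msg == "boom" {
                panic!("sink panicked");
            }
            seen.push(msg);
        }
        seen
    });

    for msg in ["a", "b", "boom"] {
        stream_tx.send(msg.to_owned()).await.unwrap();
    }
    drop(stream_tx); // end of stream

    tokio::select! {
        res = user => match res {
            Ok(out) => println!("sink finished: {out:?}"),
            Err(e) => {
                // Panic in the user task: stop reading; the real service
                // would now send the shutdown signal as in the first sketch.
                reader.abort();
                println!("sink panicked: {e}");
            }
        },
        // Never fires here since nothing cancels the token; shown only to
        // mirror the shape of the `select!` in `sink_fn`.
        _ = token.cancelled() => reader.abort(),
    }
}
```

Aborting the reader rather than the user task is deliberate: the user's sink may still be draining buffered messages, and dropping its receiver lets the forwarding loop end cleanly instead of panicking on a closed channel.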