From 243c3a12c8cf081bd673d2e2353b81587cce584b Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Mon, 22 Jul 2024 18:27:31 -0700 Subject: [PATCH] add makefile, cancellations Signed-off-by: Sidhant Kohli --- Makefile | 17 ++++++++ examples/batchmap-cat/src/main.rs | 18 ++++---- examples/reduce-counter/src/main.rs | 8 ++-- examples/sideinput/src/main.rs | 7 +-- examples/sideinput/udf/src/main.rs | 2 +- src/batchmap.rs | 67 ++++++++++++++++------------- 6 files changed, 73 insertions(+), 46 deletions(-) create mode 100644 Makefile diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2e7ee8d --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +# Description: Makefile for Rust projects + +# perform a cargo fmt on all directories containing a Cargo.toml file +.PHONY: lint +# find all directories containing Cargo.toml files +DIRS := $(shell find . -type f -name Cargo.toml -exec dirname {} \; | sort -u) +$(info Included directories: $(DIRS)) +lint: + @for dir in $(DIRS); do \ + echo "Formatting code in $$dir"; \ + cargo fmt --all --manifest-path "$$dir/Cargo.toml"; \ + done + +# run cargo test on the repository root +.PHONY: test +test: + cargo test --workspace diff --git a/examples/batchmap-cat/src/main.rs b/examples/batchmap-cat/src/main.rs index c1d208c..9bf4c73 100644 --- a/examples/batchmap-cat/src/main.rs +++ b/examples/batchmap-cat/src/main.rs @@ -12,15 +12,15 @@ struct Cat; impl batchmap::BatchMapper for Cat { async fn batchmap(&self, mut input: tokio::sync::mpsc::Receiver) -> Vec { let mut responses: Vec = Vec::new(); - while let Some(datum) = input.recv().await { - let mut response = BatchResponse::from_id(datum.id); - response.append(Message { - keys: Option::from(datum.keys), - value: datum.value, - tags: None, - }); - responses.push(response); - } + while let Some(datum) = input.recv().await { + let mut response = BatchResponse::from_id(datum.id); + response.append(Message { + keys: Some(datum.keys), + value: datum.value, + tags: None, + }); + responses.push(response); + } responses } } diff --git a/examples/reduce-counter/src/main.rs b/examples/reduce-counter/src/main.rs index 77fa50b..83146ab 100644 --- a/examples/reduce-counter/src/main.rs +++ b/examples/reduce-counter/src/main.rs @@ -9,7 +9,7 @@ async fn main() -> Result<(), Box> { mod counter { use numaflow::reduce::{Message, ReduceRequest}; - use numaflow::reduce::{Reducer, Metadata}; + use numaflow::reduce::{Metadata, Reducer}; use tokio::sync::mpsc::Receiver; use tonic::async_trait; @@ -44,8 +44,10 @@ mod counter { while input.recv().await.is_some() { counter += 1; } - let message = Message::new(counter.to_string().into_bytes()).tags(vec![]).keys(keys.clone()); + let message = Message::new(counter.to_string().into_bytes()) + .tags(vec![]) + .keys(keys.clone()); vec![message] } } -} \ No newline at end of file +} diff --git a/examples/sideinput/src/main.rs b/examples/sideinput/src/main.rs index 448a152..39af6dd 100644 --- a/examples/sideinput/src/main.rs +++ b/examples/sideinput/src/main.rs @@ -1,7 +1,6 @@ +use numaflow::sideinput::{self, SideInputer}; use std::sync::Mutex; use std::time::{SystemTime, UNIX_EPOCH}; -use numaflow::sideinput::{self, SideInputer}; - use tonic::async_trait; @@ -37,5 +36,7 @@ impl SideInputer for SideInputHandler { #[tokio::main] async fn main() -> Result<(), Box> { - sideinput::Server::new(SideInputHandler::new()).start().await + sideinput::Server::new(SideInputHandler::new()) + .start() + .await } diff --git a/examples/sideinput/udf/src/main.rs b/examples/sideinput/udf/src/main.rs index c998b76..f8d3bb5 100644 --- a/examples/sideinput/udf/src/main.rs +++ b/examples/sideinput/udf/src/main.rs @@ -1,7 +1,7 @@ use std::path::Path; use notify::{RecursiveMode, Result, Watcher}; -use numaflow::map::{Mapper, MapRequest, Message, Server}; +use numaflow::map::{MapRequest, Mapper, Message, Server}; use tokio::spawn; use tonic::async_trait; diff --git a/src/batchmap.rs b/src/batchmap.rs index 6555767..9893988 100644 --- a/src/batchmap.rs +++ b/src/batchmap.rs @@ -26,6 +26,7 @@ pub mod proto { struct BatchMapService { handler: T, _shutdown_tx: mpsc::Sender<()>, + cancellation_token: CancellationToken, } /// BatchMapper trait for implementing batch mode user defined function. @@ -263,6 +264,7 @@ where channel::>(1); let shutdown_tx = self._shutdown_tx.clone(); + let writer_cln_token = self.cancellation_token.clone(); // counter to keep track of the number of messages received let counter_orig = Arc::new(AtomicUsize::new(0)); @@ -271,16 +273,26 @@ where let counter = counter_orig.clone(); // write to the user-defined channel tokio::spawn(async move { - while let Some(next_message) = stream - .message() - .await - .expect("expected next message from stream") - { - let datum = Datum::from(next_message); - tx.send(datum) - .await - .expect("send be successfully received!"); - counter.fetch_add(1, Ordering::Relaxed); + loop { + tokio::select! { + next_message = stream.message() => { + match next_message { + Ok(Some(message)) => { + let datum = Datum::from(message); + if tx.send(datum).await.is_err() { + break; + } + counter.fetch_add(1, Ordering::Relaxed); + }, + // If there's an error or the stream ends, break the loop to stop the task. + Ok(None) | Err(_) => break, + } + }, + // Listen for cancellation. If triggered, break the loop to stop reading new messages. + _ = writer_cln_token.cancelled() => { + break; + } + } } }); @@ -315,11 +327,9 @@ where })) .await; // if the send fails, return an error status on the streaming endpoint - if send_result.is_err() { + if let Err(e) = send_result { grpc_response_tx - .send(Err(Status::internal( - send_result.err().unwrap().to_string(), - ))) + .send(Err(Status::internal(e.to_string()))) .await .expect("send to grpc response channel failed"); return; @@ -400,22 +410,21 @@ impl crate::batchmap::Server { shared::create_listener_stream(&self.sock_addr, &self.server_info_file, info)?; let handler = self.svc.take().unwrap(); + let cln_token = CancellationToken::new(); + // Create a channel to send shutdown signal to the server to do graceful shutdown in case of non retryable errors. let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); let map_svc = crate::batchmap::BatchMapService { handler, _shutdown_tx: internal_shutdown_tx, + cancellation_token: cln_token.clone(), }; let map_svc = proto::batch_map_server::BatchMapServer::new(map_svc) .max_encoding_message_size(self.max_message_size) .max_decoding_message_size(self.max_message_size); - let shutdown = shutdown_signal( - internal_shutdown_rx, - Some(shutdown_rx), - CancellationToken::new(), - ); + let shutdown = shutdown_signal(internal_shutdown_rx, Some(shutdown_rx), cln_token); tonic::transport::Server::builder() .add_service(map_svc) @@ -608,18 +617,16 @@ mod tests { .await?; let mut r = resp.into_inner(); - let mut error_flag = false; + let Err(server_err) = r.message().await else { + return Err("Expected error from server".into()); + }; + + assert_eq!(server_err.code(), tonic::Code::Internal); + assert!(server_err.message().contains( + "number of responses does not \ + match the number of messages received" + )); - if let Err(e) = r.message().await { - assert_eq!(e.code(), tonic::Code::Internal); - assert!(e.message().contains( - "number of responses does not \ - match the number of messages received" - )); - error_flag = true; - } - // Check if the error flag is set - assert!(error_flag, "Expected error from server"); tokio::time::sleep(Duration::from_millis(50)).await; assert!(task.is_finished(), "gRPC server is still running"); Ok(())