From d4c8494b0d5c12d0b4731fdba9ef31446992e6b8 Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Fri, 4 Oct 2024 15:26:51 +0530 Subject: [PATCH 1/2] feat: Use gRPC Bidirectional Streaming for Map Signed-off-by: Yashash H L --- proto/map.proto | 28 ++- src/batchmap.rs | 25 +-- src/map.rs | 490 ++++++++++++++++++++++++++++++----------- src/shared.rs | 3 +- src/sourcetransform.rs | 58 ++--- 5 files changed, 429 insertions(+), 175 deletions(-) diff --git a/proto/map.proto b/proto/map.proto index 07433dd..f3761d1 100644 --- a/proto/map.proto +++ b/proto/map.proto @@ -7,7 +7,7 @@ package map.v1; service Map { // MapFn applies a function to each map request element. - rpc MapFn(MapRequest) returns (MapResponse); + rpc MapFn(stream MapRequest) returns (stream MapResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); @@ -17,12 +17,25 @@ service Map { * MapRequest represents a request element. */ message MapRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - map headers = 5; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + map headers = 5; + } + Request request = 1; + // This ID is used to uniquely identify a map request + string id = 2; + optional Handshake handshake = 3; +} +/* + * Handshake message between client and server to indicate the start of transmission. + */ +message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; } /** @@ -35,6 +48,9 @@ message MapResponse { repeated string tags = 3; } repeated Result results = 1; + // This ID is used to refer the responses to the request it corresponds to. + string id = 2; + optional Handshake handshake = 3; } /** diff --git a/src/batchmap.rs b/src/batchmap.rs index 7109fe9..9814a59 100644 --- a/src/batchmap.rs +++ b/src/batchmap.rs @@ -118,7 +118,7 @@ pub struct Message { } /// Represents a message that can be modified and forwarded. -impl crate::batchmap::Message { +impl Message { /// Creates a new message with the specified value. /// /// This constructor initializes the message with no keys, tags, or specific event time. @@ -148,11 +148,11 @@ impl crate::batchmap::Message { /// use numaflow::batchmap::Message; /// let dropped_message = Message::message_to_drop(); /// ``` - pub fn message_to_drop() -> crate::batchmap::Message { - crate::batchmap::Message { + pub fn message_to_drop() -> Message { + Message { keys: None, value: vec![], - tags: Some(vec![crate::batchmap::DROP.to_string()]), + tags: Some(vec![DROP.to_string()]), } } @@ -245,11 +245,8 @@ impl BatchMap for BatchMapService where T: BatchMapper + Send + Sync + 'static, { - async fn is_ready( - &self, - _: Request<()>, - ) -> Result, Status> { - Ok(tonic::Response::new(proto::ReadyResponse { ready: true })) + async fn is_ready(&self, _: Request<()>) -> Result, Status> { + Ok(Response::new(proto::ReadyResponse { ready: true })) } type BatchMapFnStream = ReceiverStream>; @@ -261,7 +258,7 @@ where let mut stream = request.into_inner(); // Create a channel to send the messages to the user defined function. - let (tx, rx) = mpsc::channel::(1); + let (tx, rx) = channel::(1); // Create a channel to send the response back to the grpc client. let (grpc_response_tx, grpc_response_rx) = @@ -418,9 +415,9 @@ pub struct Server { server_info_file: PathBuf, svc: Option, } -impl crate::batchmap::Server { +impl Server { pub fn new(batch_map_svc: T) -> Self { - crate::batchmap::Server { + Server { sock_addr: DEFAULT_SOCK_ADDR.into(), max_message_size: DEFAULT_MAX_MESSAGE_SIZE, server_info_file: DEFAULT_SERVER_INFO_FILE.into(), @@ -478,8 +475,8 @@ impl crate::batchmap::Server { let cln_token = CancellationToken::new(); // Create a channel to send shutdown signal to the server to do graceful shutdown in case of non retryable errors. - let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); - let map_svc = crate::batchmap::BatchMapService { + let (internal_shutdown_tx, internal_shutdown_rx) = channel(1); + let map_svc = BatchMapService { handler: Arc::new(handler), _shutdown_tx: internal_shutdown_tx, cancellation_token: cln_token.clone(), diff --git a/src/map.rs b/src/map.rs index eec2294..2ba1fca 100644 --- a/src/map.rs +++ b/src/map.rs @@ -1,5 +1,5 @@ -use crate::error::Error::MapError; -use crate::error::ErrorKind::{InternalError, UserDefinedError}; +use crate::error::{Error, ErrorKind}; +use crate::map::proto::MapResponse; use crate::shared::{self, shutdown_signal, ContainerType}; use chrono::{DateTime, Utc}; use std::collections::HashMap; @@ -7,9 +7,13 @@ use std::fs; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; +use tokio::task::JoinHandle; +use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; -use tonic::{async_trait, Request, Response, Status}; +use tonic::{async_trait, Request, Response, Status, Streaming}; +use tracing::{error, info}; +const DEFAULT_CHANNEL_SIZE: usize = 1000; const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/map.sock"; const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/mapper-server-info"; @@ -60,53 +64,6 @@ pub trait Mapper { async fn map(&self, input: MapRequest) -> Vec; } -#[async_trait] -impl proto::map_server::Map for MapService -where - T: Mapper + Send + Sync + 'static, -{ - async fn map_fn( - &self, - request: Request, - ) -> Result, Status> { - let request = request.into_inner(); - let handler = Arc::clone(&self.handler); - let handle = tokio::spawn(async move { handler.map(request.into()).await }); - let shutdown_tx = self.shutdown_tx.clone(); - let cancellation_token = self.cancellation_token.clone(); - - // Wait for the handler to finish processing the request. If the server is shutting down(token will be cancelled), - // then return an error. - tokio::select! { - result = handle => { - match result { - Ok(result) => Ok(Response::new(proto::MapResponse { - results: result.into_iter().map(|msg| msg.into()).collect(), - })), - Err(e) => { - tracing::error!("Error in map handler: {:?}", e); - // Send a shutdown signal to the server to do a graceful shutdown because there was - // a panic in the handler. - shutdown_tx - .send(()) - .await - .expect("Sending shutdown signal to gRPC server"); - Err(Status::internal(MapError(UserDefinedError(e.to_string())).to_string())) - } - } - }, - - _ = cancellation_token.cancelled() => { - Err(Status::internal(MapError(InternalError("Server is shutting down".to_string())).to_string())) - }, - } - } - - async fn is_ready(&self, _: Request<()>) -> Result, Status> { - Ok(Response::new(proto::ReadyResponse { ready: true })) - } -} - /// Message is the response struct from the [`Mapper::map`] . #[derive(Debug, PartialEq)] pub struct Message { @@ -234,8 +191,8 @@ pub struct MapRequest { pub headers: HashMap, } -impl From for MapRequest { - fn from(value: proto::MapRequest) -> Self { +impl From for MapRequest { + fn from(value: proto::map_request::Request) -> Self { Self { keys: value.keys, value: value.value, @@ -246,6 +203,235 @@ impl From for MapRequest { } } +#[async_trait] +impl proto::map_server::Map for MapService +where + T: Mapper + Send + Sync + 'static, +{ + type MapFnStream = ReceiverStream>; + + async fn map_fn( + &self, + request: Request>, + ) -> Result, Status> { + let mut stream = request.into_inner(); + let handler = Arc::clone(&self.handler); + + let (stream_response_tx, stream_response_rx) = + mpsc::channel::>(DEFAULT_CHANNEL_SIZE); + + // perform handshake + perform_handshake(&mut stream, &stream_response_tx).await?; + + let (error_tx, error_rx) = mpsc::channel::(1); + + // Spawn a task to handle incoming stream requests + let handle: JoinHandle<()> = tokio::spawn(handle_stream_requests( + handler.clone(), + stream, + stream_response_tx.clone(), + error_tx.clone(), + self.cancellation_token.child_token(), + )); + + tokio::spawn(manage_grpc_stream( + handle, + self.cancellation_token.clone(), + stream_response_tx, + error_rx, + self.shutdown_tx.clone(), + )); + + Ok(Response::new(ReceiverStream::new(stream_response_rx))) + } + + async fn is_ready(&self, _: Request<()>) -> Result, Status> { + Ok(Response::new(proto::ReadyResponse { ready: true })) + } +} + +async fn handle_stream_requests( + handler: Arc, + mut stream: Streaming, + stream_response_tx: mpsc::Sender>, + error_tx: mpsc::Sender, + token: CancellationToken, +) where + T: Mapper + Send + Sync + 'static, +{ + let mut stream_open = true; + while stream_open { + stream_open = tokio::select! { + map_request = stream.message() => handle_request( + handler.clone(), + map_request, + stream_response_tx.clone(), + error_tx.clone(), + token.clone(), + ).await, + _ = token.cancelled() => { + info!("Cancellation token is cancelled, shutting down"); + break; + } + } + } +} + +async fn manage_grpc_stream( + request_handler: JoinHandle<()>, + token: CancellationToken, + stream_response_tx: mpsc::Sender>, + mut error_rx: mpsc::Receiver, + server_shutdown_tx: mpsc::Sender<()>, +) { + let err = tokio::select! { + _ = request_handler => { + token.cancel(); + return; + }, + err = error_rx.recv() => err, + }; + + token.cancel(); + let Some(err) = err else { + return; + }; + error!("Shutting down gRPC channel: {err:?}"); + stream_response_tx + .send(Err(Status::internal(err.to_string()))) + .await + .expect("Sending error message to gRPC response channel"); + server_shutdown_tx + .send(()) + .await + .expect("Writing to shutdown channel"); +} + +async fn handle_request( + handler: Arc, + map_request: Result, Status>, + stream_response_tx: mpsc::Sender>, + error_tx: mpsc::Sender, + token: CancellationToken, +) -> bool +where + T: Mapper + Send + Sync + 'static, +{ + let map_request = match map_request { + Ok(None) => return false, + Ok(Some(val)) => val, + Err(val) => { + error!("Received gRPC error from sender: {val:?}"); + return false; + } + }; + tokio::spawn(run_map( + handler, + map_request, + stream_response_tx, + error_tx, + token, + )); + true +} + +async fn run_map( + handler: Arc, + map_request: proto::MapRequest, + stream_response_tx: mpsc::Sender>, + error_tx: mpsc::Sender, + token: CancellationToken, +) where + T: Mapper + Send + Sync + 'static, +{ + let Some(request) = map_request.request else { + error_tx + .send(Error::MapError(ErrorKind::InternalError( + "Request not present".to_string(), + ))) + .await + .expect("Sending error on channel"); + return; + }; + + let message_id = map_request.id.clone(); + + // A new task is spawned to catch the panic + let udf_map_task = tokio::spawn({ + let handler = handler.clone(); + let token = token.child_token(); + async move { + tokio::select! { + _ = token.cancelled() => None, + messages = handler.map(request.into()) => Some(messages), + } + } + }); + + let messages = match udf_map_task.await { + Ok(messages) => messages, + Err(e) => { + error!("Failed to run map function: {e:?}"); + error_tx + .send(Error::MapError(ErrorKind::InternalError(format!( + "panicked: {e:?}" + )))) + .await + .expect("Sending error on channel"); + return; + } + }; + + let Some(messages) = messages else { + // CancellationToken is cancelled + return; + }; + + let send_response_result = stream_response_tx + .send(Ok(MapResponse { + results: messages.into_iter().map(|msg| msg.into()).collect(), + id: message_id, + handshake: None, + })) + .await; + + let Err(e) = send_response_result else { + return; + }; + + error_tx + .send(Error::MapError(ErrorKind::InternalError(format!( + "Failed to send response: {e:?}" + )))) + .await + .expect("Sending error on channel"); +} + +async fn perform_handshake( + stream: &mut Streaming, + stream_response_tx: &mpsc::Sender>, +) -> Result<(), Status> { + let handshake_request = stream + .message() + .await + .map_err(|e| Status::internal(format!("Handshake failed: {}", e)))? + .ok_or_else(|| Status::internal("Stream closed before handshake"))?; + + if let Some(handshake) = handshake_request.handshake { + stream_response_tx + .send(Ok(MapResponse { + results: vec![], + id: "".to_string(), + handshake: Some(handshake), + })) + .await + .map_err(|e| Status::internal(format!("Failed to send handshake response: {}", e)))?; + Ok(()) + } else { + Err(Status::invalid_argument("Handshake not present")) + } +} + /// gRPC server to start a map service #[derive(Debug)] pub struct Server { @@ -362,10 +548,11 @@ mod tests { use crate::map::proto::map_client::MapClient; use std::{error::Error, time::Duration}; + use crate::map::proto; use tempfile::TempDir; use tokio::net::UnixStream; - use tokio::sync::oneshot; - use tokio::time::sleep; + use tokio::sync::{mpsc, oneshot}; + use tokio_stream::wrappers::ReceiverStream; use tonic::transport::Uri; use tower::service_fn; @@ -415,21 +602,51 @@ mod tests { .await?; let mut client = MapClient::new(channel); - let request = tonic::Request::new(map::proto::MapRequest { - keys: vec!["first".into(), "second".into()], - value: "hello".into(), - watermark: Some(prost_types::Timestamp::default()), - event_time: Some(prost_types::Timestamp::default()), - headers: Default::default(), - }); - - let resp = client.map_fn(request).await?; - let resp = resp.into_inner(); - assert_eq!(resp.results.len(), 1, "Expected single message from server"); - let msg = &resp.results[0]; + let request = proto::MapRequest { + request: Some(proto::map_request::Request { + keys: vec!["first".into(), "second".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + headers: Default::default(), + }), + id: "".to_string(), + handshake: None, + }; + + let (tx, rx) = mpsc::channel(2); + let handshake_request = proto::MapRequest { + request: None, + id: "".to_string(), + handshake: Some(proto::Handshake { sot: true }), + }; + + tx.send(handshake_request).await?; + tx.send(request).await?; + + let resp = client.map_fn(ReceiverStream::new(rx)).await?; + let mut resp = resp.into_inner(); + + let handshake_response = resp.message().await?; + assert!(handshake_response.is_some()); + + let handshake_response = handshake_response.unwrap(); + assert!(handshake_response.handshake.is_some()); + + let actual_response = resp.message().await?; + assert!(actual_response.is_some()); + + let actual_response = actual_response.unwrap(); + assert_eq!( + actual_response.results.len(), + 1, + "Expected single message from server" + ); + let msg = &actual_response.results[0]; assert_eq!(msg.keys.first(), Some(&"first".to_owned())); assert_eq!(msg.value, "hello".as_bytes()); + drop(tx); shutdown_tx .send(()) .expect("Sending shutdown signal to gRPC server"); @@ -440,11 +657,11 @@ mod tests { #[tokio::test] async fn map_server_panic() -> Result<(), Box> { - struct PanicCat; + struct PanicMapper; #[tonic::async_trait] - impl map::Mapper for PanicCat { - async fn map(&self, _input: map::MapRequest) -> Vec { - panic!("PanicCat panicking"); + impl map::Mapper for PanicMapper { + async fn map(&self, _: map::MapRequest) -> Vec { + panic!("Panic in mapper"); } } @@ -452,7 +669,7 @@ mod tests { let sock_file = tmp_dir.path().join("map.sock"); let server_info_file = tmp_dir.path().join("mapper-server-info"); - let mut server = map::Server::new(PanicCat) + let mut server = map::Server::new(PanicMapper) .with_server_info_file(&server_info_file) .with_socket_file(&sock_file) .with_max_message_size(10240); @@ -466,8 +683,10 @@ mod tests { tokio::time::sleep(Duration::from_millis(50)).await; + // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")? .connect_with_connector(service_fn(move |_: Uri| { + // https://rust-lang.github.io/async-book/03_async_await/01_chapter.html#async-lifetimes let sock_file = sock_file.clone(); async move { Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( @@ -478,22 +697,41 @@ mod tests { .await?; let mut client = MapClient::new(channel); - let request = tonic::Request::new(map::proto::MapRequest { - keys: vec!["first".into(), "second".into()], - value: "hello".into(), - watermark: Some(prost_types::Timestamp::default()), - event_time: Some(prost_types::Timestamp::default()), - headers: Default::default(), - }); - - // server should return an error because of the panic. - let resp = client.map_fn(request).await; - assert!(resp.is_err(), "Expected error from server"); - - if let Err(e) = resp { - assert_eq!(e.code(), tonic::Code::Internal); - assert!(e.message().contains("User Defined Error")); - } + + let (tx, rx) = mpsc::channel(2); + let handshake_request = proto::MapRequest { + request: None, + id: "".to_string(), + handshake: Some(proto::Handshake { sot: true }), + }; + tx.send(handshake_request).await.unwrap(); + + let mut stream = tokio::time::timeout( + Duration::from_secs(2), + client.map_fn(ReceiverStream::new(rx)), + ) + .await + .map_err(|_| "timeout while getting stream for map_fn")?? + .into_inner(); + + let handshake_resp = stream.message().await?.unwrap(); + assert!( + handshake_resp.handshake.is_some(), + "Not a valid response for handshake request" + ); + + let request = proto::MapRequest { + request: Some(proto::map_request::Request { + keys: vec!["three".into(), "four".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + headers: Default::default(), + }), + id: "".to_string(), + handshake: None, + }; + tx.send(request).await.unwrap(); // server should shut down gracefully because there was a panic in the handler. for _ in 0..10 { @@ -511,17 +749,11 @@ mod tests { // should shut down gracefully. #[tokio::test] async fn panic_with_multiple_requests() -> Result<(), Box> { - struct PanicCat; + struct PanicMapper; #[tonic::async_trait] - impl map::Mapper for PanicCat { - async fn map(&self, input: map::MapRequest) -> Vec { - if !input.keys.is_empty() && input.keys[0] == "key1" { - sleep(Duration::from_millis(20)).await; - panic!("Cat panicked"); - } - // assume each request takes 100ms to process - sleep(Duration::from_millis(100)).await; - vec![] + impl map::Mapper for PanicMapper { + async fn map(&self, _: map::MapRequest) -> Vec { + panic!("Panic in mapper"); } } @@ -529,7 +761,7 @@ mod tests { let sock_file = tmp_dir.path().join("map.sock"); let server_info_file = tmp_dir.path().join("mapper-server-info"); - let mut server = map::Server::new(PanicCat) + let mut server = map::Server::new(PanicMapper) .with_server_info_file(&server_info_file) .with_socket_file(&sock_file) .with_max_message_size(10240); @@ -543,6 +775,7 @@ mod tests { tokio::time::sleep(Duration::from_millis(50)).await; + // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")? .connect_with_connector(service_fn(move |_: Uri| { let sock_file = sock_file.clone(); @@ -556,47 +789,48 @@ mod tests { let mut client = MapClient::new(channel); - let mut client_one = client.clone(); - tokio::spawn(async move { - let request = tonic::Request::new(map::proto::MapRequest { - keys: vec!["key2".into()], + let (tx, rx) = mpsc::channel(2); + let handshake_request = proto::MapRequest { + request: None, + id: "".to_string(), + handshake: Some(proto::Handshake { sot: true }), + }; + tx.send(handshake_request).await.unwrap(); + + let mut stream = tokio::time::timeout( + Duration::from_secs(2), + client.map_fn(ReceiverStream::new(rx)), + ) + .await + .map_err(|_| "timeout while getting stream for map_fn")?? + .into_inner(); + + let handshake_resp = stream.message().await?.unwrap(); + assert!( + handshake_resp.handshake.is_some(), + "Not a valid response for handshake request" + ); + + let request = proto::MapRequest { + request: Some(proto::map_request::Request { + keys: vec!["five".into(), "six".into()], value: "hello".into(), watermark: Some(prost_types::Timestamp::default()), event_time: Some(prost_types::Timestamp::default()), headers: Default::default(), - }); - - // panic is only for requests with key "key1", since we have graceful shutdown - // the request should get processed. - let resp = client_one.map_fn(request).await; - assert!(resp.is_ok(), "Expected ok from server"); - }); - - let request = tonic::Request::new(map::proto::MapRequest { - keys: vec!["key1".into()], - value: "hello".into(), - watermark: Some(prost_types::Timestamp::default()), - event_time: Some(prost_types::Timestamp::default()), - headers: Default::default(), - }); - - // panic happens for the key1 request, so we should expect error on the client side. - let resp = client.map_fn(request).await; - assert!(resp.is_err(), "Expected error from server"); - - if let Err(e) = resp { - assert_eq!(e.code(), tonic::Code::Internal); - assert!(e.message().contains("User Defined Error")); - } + }), + id: "".to_string(), + handshake: None, + }; + tx.send(request).await.unwrap(); - // but since there is a panic, the server should shutdown. + // server should shut down gracefully because there was a panic in the handler. for _ in 0..10 { tokio::time::sleep(Duration::from_millis(10)).await; if task.is_finished() { break; } } - assert!(task.is_finished(), "gRPC server is still running"); Ok(()) } diff --git a/src/shared.rs b/src/shared.rs index 2754507..c5380e6 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -91,7 +91,7 @@ impl ServerInfo { minimum_numaflow_version: MINIMUM_NUMAFLOW_VERSION .get(&container_type) .map(|&version| version.to_string()) - .unwrap_or_else(String::new), + .unwrap_or_default(), version: SDK_VERSION.to_string(), metadata: Option::from(metadata), } @@ -138,6 +138,7 @@ pub(crate) fn prost_timestamp_from_utc(t: DateTime) -> Option { /// shuts downs the gRPC server. This happens in 2 cases /// 1. there has been an internal error (one of the tasks failed) and we need to shutdown /// 2. user is explicitly asking us to shutdown +/// /// Once the request for shutdown has be invoked, server will broadcast shutdown to all tasks /// through the cancellation-token. pub(crate) async fn shutdown_signal( diff --git a/src/sourcetransform.rs b/src/sourcetransform.rs index 9d8fe04..b3ad454 100644 --- a/src/sourcetransform.rs +++ b/src/sourcetransform.rs @@ -255,26 +255,7 @@ where mpsc::channel::>(DEFAULT_CHANNEL_SIZE); // do the handshake first to let the client know that we are ready to receive transformation requests. - let handshake_request = stream - .message() - .await - .map_err(|e| Status::internal(format!("handshake failed {}", e)))? - .ok_or_else(|| Status::internal("stream closed before handshake"))?; - - if let Some(handshake) = handshake_request.handshake { - stream_response_tx - .send(Ok(SourceTransformResponse { - results: vec![], - id: "".to_string(), - handshake: Some(handshake), - })) - .await - .map_err(|e| { - Status::internal(format!("failed to send handshake response {}", e)) - })?; - } else { - return Err(Status::invalid_argument("Handshake not present")); - } + perform_handshake(&mut stream, &stream_response_tx).await?; let (error_tx, error_rx) = mpsc::channel::(1); @@ -288,7 +269,7 @@ where self.cancellation_token.child_token(), )); - tokio::spawn(manage_gprc_stream( + tokio::spawn(manage_grpc_stream( handle, self.cancellation_token.clone(), stream_response_tx, @@ -304,8 +285,33 @@ where } } +async fn perform_handshake( + stream: &mut Streaming, + stream_response_tx: &mpsc::Sender>, +) -> Result<(), Status> { + let handshake_request = stream + .message() + .await + .map_err(|e| Status::internal(format!("Handshake failed: {}", e)))? + .ok_or_else(|| Status::internal("Stream closed before handshake"))?; + + if let Some(handshake) = handshake_request.handshake { + stream_response_tx + .send(Ok(SourceTransformResponse { + results: vec![], + id: "".to_string(), + handshake: Some(handshake), + })) + .await + .map_err(|e| Status::internal(format!("Failed to send handshake response: {}", e)))?; + Ok(()) + } else { + Err(Status::invalid_argument("Handshake not present")) + } +} + // shutdown the gRPC server on first error -async fn manage_gprc_stream( +async fn manage_grpc_stream( request_handler: JoinHandle<()>, token: CancellationToken, stream_response_tx: mpsc::Sender>, @@ -335,8 +341,8 @@ async fn manage_gprc_stream( .expect("Writing to shutdown channel"); } -// Receives messages from the stream. -// For each message received from the stream, a new task is spawned to call the transform function and send the response back to the client +// Receives messages from the stream. For each message received from the stream, +// a new task is spawned to call the transform function and send the response back to the client async fn handle_stream_requests( handler: Arc, mut stream: Streaming, @@ -673,7 +679,7 @@ mod tests { "Not a valid response for handshake request" ); - let request = sourcetransform::proto::SourceTransformRequest { + let request = proto::SourceTransformRequest { request: Some(proto::source_transform_request::Request { id: "1".to_string(), keys: vec!["first".into(), "second".into()], @@ -772,7 +778,7 @@ mod tests { let request = proto::SourceTransformRequest { request: Some(proto::source_transform_request::Request { - id: "1".to_string(), + id: "2".to_string(), keys: vec!["first".into(), "second".into()], value: "hello".into(), watermark: Some(prost_types::Timestamp::default()), From 19ab43ec7fc97b14992d8ec45269d313388ccd8b Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Tue, 8 Oct 2024 13:33:56 +0530 Subject: [PATCH 2/2] format imports, review comments Signed-off-by: Yashash H L --- src/batchmap.rs | 3 ++- src/map.rs | 17 ++++++++++------- src/shared.rs | 11 +++++++---- src/sideinput.rs | 8 +++++--- src/sink.rs | 15 ++++++++------- src/source.rs | 18 ++++++++++-------- 6 files changed, 42 insertions(+), 30 deletions(-) diff --git a/src/batchmap.rs b/src/batchmap.rs index 9814a59..4112618 100644 --- a/src/batchmap.rs +++ b/src/batchmap.rs @@ -1,9 +1,10 @@ -use chrono::{DateTime, Utc}; use std::collections::HashMap; use std::fs; use std::path::PathBuf; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; + +use chrono::{DateTime, Utc}; use tokio::sync::mpsc::channel; use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::ReceiverStream; diff --git a/src/map.rs b/src/map.rs index 2ba1fca..a12340e 100644 --- a/src/map.rs +++ b/src/map.rs @@ -1,11 +1,9 @@ -use crate::error::{Error, ErrorKind}; -use crate::map::proto::MapResponse; -use crate::shared::{self, shutdown_signal, ContainerType}; -use chrono::{DateTime, Utc}; use std::collections::HashMap; use std::fs; use std::path::PathBuf; use std::sync::Arc; + +use chrono::{DateTime, Utc}; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; use tokio_stream::wrappers::ReceiverStream; @@ -13,6 +11,10 @@ use tokio_util::sync::CancellationToken; use tonic::{async_trait, Request, Response, Status, Streaming}; use tracing::{error, info}; +use crate::error::{Error, ErrorKind}; +use crate::map::proto::MapResponse; +use crate::shared::{self, shutdown_signal, ContainerType}; + const DEFAULT_CHANNEL_SIZE: usize = 1000; const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/map.sock"; @@ -544,11 +546,8 @@ impl Drop for Server { #[cfg(test)] mod tests { - use crate::map; - use crate::map::proto::map_client::MapClient; use std::{error::Error, time::Duration}; - use crate::map::proto; use tempfile::TempDir; use tokio::net::UnixStream; use tokio::sync::{mpsc, oneshot}; @@ -556,6 +555,10 @@ mod tests { use tonic::transport::Uri; use tower::service_fn; + use crate::map; + use crate::map::proto; + use crate::map::proto::map_client::MapClient; + #[tokio::test] async fn map_server() -> Result<(), Box> { struct Cat; diff --git a/src/shared.rs b/src/shared.rs index c5380e6..b22f585 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -1,10 +1,11 @@ -use chrono::{DateTime, TimeZone, Timelike, Utc}; -use prost_types::Timestamp; -use serde::{Deserialize, Serialize}; use std::fs; use std::path::Path; use std::sync::LazyLock; use std::{collections::HashMap, io}; + +use chrono::{DateTime, TimeZone, Timelike, Utc}; +use prost_types::Timestamp; +use serde::{Deserialize, Serialize}; use tokio::net::UnixListener; use tokio::signal; use tokio::sync::{mpsc, oneshot}; @@ -178,11 +179,13 @@ pub(crate) async fn shutdown_signal( #[cfg(test)] mod tests { - use super::*; use std::fs::File; use std::io::Read; + use tempfile::NamedTempFile; + use super::*; + #[test] fn test_utc_from_timestamp() { let specific_date = Utc.with_ymd_and_hms(2022, 7, 2, 2, 0, 0).unwrap(); diff --git a/src/sideinput.rs b/src/sideinput.rs index c5e8c0e..93a6bd0 100644 --- a/src/sideinput.rs +++ b/src/sideinput.rs @@ -1,13 +1,15 @@ -use crate::error::Error::SideInputError; -use crate::error::ErrorKind::{InternalError, UserDefinedError}; -use crate::shared::{self, shutdown_signal, ContainerType}; use std::fs; use std::path::PathBuf; use std::sync::Arc; + use tokio::sync::{mpsc, oneshot}; use tokio_util::sync::CancellationToken; use tonic::{async_trait, Request, Response, Status}; +use crate::error::Error::SideInputError; +use crate::error::ErrorKind::{InternalError, UserDefinedError}; +use crate::shared::{self, shutdown_signal, ContainerType}; + const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sideinput.sock"; const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sideinput-server-info"; diff --git a/src/sink.rs b/src/sink.rs index 824d0af..329a3dc 100644 --- a/src/sink.rs +++ b/src/sink.rs @@ -1,14 +1,9 @@ -use crate::error::Error; -use crate::error::Error::SinkError; -use crate::error::ErrorKind::{InternalError, UserDefinedError}; -use crate::shared::{self, ContainerType}; -use crate::sink::sink_pb::SinkResponse; - -use chrono::{DateTime, Utc}; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use std::{env, fs}; + +use chrono::{DateTime, Utc}; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; use tokio_stream::wrappers::ReceiverStream; @@ -16,6 +11,12 @@ use tokio_util::sync::CancellationToken; use tonic::{Request, Status, Streaming}; use tracing::{debug, info}; +use crate::error::Error; +use crate::error::Error::SinkError; +use crate::error::ErrorKind::{InternalError, UserDefinedError}; +use crate::shared::{self, ContainerType}; +use crate::sink::sink_pb::SinkResponse; + const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sink.sock"; const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sinker-server-info"; diff --git a/src/source.rs b/src/source.rs index c50894a..6a67fbc 100644 --- a/src/source.rs +++ b/src/source.rs @@ -4,11 +4,6 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; -use crate::error::Error::SourceError; -use crate::error::{Error, ErrorKind}; -use crate::shared::{self, prost_timestamp_from_utc, ContainerType}; -use crate::source::proto::{AckRequest, AckResponse, ReadRequest, ReadResponse}; - use chrono::{DateTime, Utc}; use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::sync::oneshot; @@ -18,6 +13,11 @@ use tokio_util::sync::CancellationToken; use tonic::{async_trait, Request, Response, Status, Streaming}; use tracing::{error, info}; +use crate::error::Error::SourceError; +use crate::error::{Error, ErrorKind}; +use crate::shared::{self, prost_timestamp_from_utc, ContainerType}; +use crate::source::proto::{AckRequest, AckResponse, ReadRequest, ReadResponse}; + const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/source.sock"; const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcer-server-info"; @@ -543,13 +543,12 @@ impl Drop for Server { #[cfg(test)] mod tests { - use super::{proto, Message, Offset, SourceReadRequest}; - use crate::source; - use chrono::Utc; use std::collections::{HashMap, HashSet}; use std::error::Error; use std::time::Duration; use std::vec; + + use chrono::Utc; use tempfile::TempDir; use tokio::net::UnixStream; use tokio::sync::mpsc::Sender; @@ -560,6 +559,9 @@ mod tests { use tower::service_fn; use uuid::Uuid; + use super::{proto, Message, Offset, SourceReadRequest}; + use crate::source; + // A source that repeats the `num` for the requested count struct Repeater { num: usize,