From 9644d9e76688c6c5320a9b03bbcd727a89b57fcd Mon Sep 17 00:00:00 2001 From: Vigith Maurice Date: Wed, 11 Sep 2024 19:24:48 -0700 Subject: [PATCH 1/7] feat: streaming source Signed-off-by: Vigith Maurice --- proto/source.proto | 48 ++++++++++++++++++++------- src/source.rs | 83 +++++----------------------------------------- 2 files changed, 44 insertions(+), 87 deletions(-) diff --git a/proto/source.proto b/proto/source.proto index 4352028..dcaf253 100644 --- a/proto/source.proto +++ b/proto/source.proto @@ -7,16 +7,17 @@ package source.v1; service Source { // Read returns a stream of datum responses. - // The size of the returned ReadResponse is less than or equal to the num_records specified in ReadRequest. - // If the request timeout is reached on server side, the returned ReadResponse will contain all the datum that have been read (which could be an empty list). - rpc ReadFn(ReadRequest) returns (stream ReadResponse); + // The size of the returned ReadResponse is less than or equal to the num_records specified in each ReadRequest. + // If the request timeout is reached on the server side, the returned ReadResponse will contain all the datum that have been read (which could be an empty list). + // The server will continue to read and respond to subsequent ReadRequests until the client closes the stream. + rpc ReadFn(stream ReadRequest) returns (stream ReadResponse); - // AckFn acknowledges a list of datum offsets. + // AckFn acknowledges a stream of datum offsets. // When AckFn is called, it implicitly indicates that the datum stream has been processed by the source vertex. // The caller (numa) expects the AckFn to be successful, and it does not expect any errors. // If there are some irrecoverable errors when the callee (UDSource) is processing the AckFn request, - // then it is best to crash because there are no other retry mechanisms possible. 
- rpc AckFn(AckRequest) returns (AckResponse); + // then it is best to crash because there are no other retry mechanisms possible. + rpc AckFn(stream AckRequest) returns (AckResponse); // PendingFn returns the number of pending records at the user defined source. rpc PendingFn(google.protobuf.Empty) returns (PendingResponse); @@ -60,9 +61,35 @@ message ReadResponse { // We add this optional field to support the use case where the user defined source can provide keys for the datum. // e.g. Kafka and Redis Stream message usually include information about the keys. repeated string keys = 4; + // Optional list of headers associated with the datum. + // Headers are the metadata associated with the datum. + // e.g. Kafka and Redis Stream message usually include information about the headers. + map headers = 5; + } + message Status { + // Code to indicate the status of the response. + enum Code { + SUCCESS = 0; + FAILURE = 1; + } + + // Error to indicate the error type. If the code is FAILURE, then the error field will be populated. + enum Error { + UNACKED = 0; + OTHER = 1; + } + + // End of transmission flag. + bool eot = 1; + Code code = 2; + Error error = 3; + optional string msg = 4; } // Required field holding the result. Result result = 1; + // Status of the response. Holds the end of transmission flag and the status code. + // + Status status = 2; } /* @@ -71,11 +98,8 @@ message ReadResponse { */ message AckRequest { message Request { - // Required field holding a list of offsets to be acknowledged. - // The offsets must be strictly corresponding to the previously read batch, - // meaning the offsets must be in the same order as the datum responses in the ReadResponse. - // By enforcing ordering, we can save deserialization effort on the server side, assuming the server keeps a local copy of the raw/un-serialized offsets. - repeated Offset offsets = 1; + // Required field holding the offset to be acked + Offset offset = 1; } // Required field holding the request. 
The list will be ordered and will have the same order as the original Read response. Request request = 1; @@ -146,4 +170,4 @@ message Offset { // It is useful for sources that have multiple partitions. e.g. Kafka. // If the partition_id is not specified, it is assumed that the source has a single partition. int32 partition_id = 2; -} \ No newline at end of file +} diff --git a/src/source.rs b/src/source.rs index af3cf3e..15dae3b 100644 --- a/src/source.rs +++ b/src/source.rs @@ -10,7 +10,8 @@ use tokio::sync::mpsc::{self, Sender}; use tokio::sync::oneshot; use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; -use tonic::{async_trait, Request, Response, Status}; +use tonic::{async_trait, Request, Response, Status, Streaming}; +use crate::source::proto::{AckRequest, AckResponse, ReadRequest}; const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/source.sock"; @@ -27,6 +28,7 @@ struct SourceService { _cancellation_token: CancellationToken, } +// FIXME: remove async_trait #[async_trait] /// Trait representing a [user defined source](https://numaflow.numaproj.io/user-guide/sources/overview/). /// @@ -80,84 +82,15 @@ where { type ReadFnStream = ReceiverStream>; - async fn read_fn( - &self, - request: Request, - ) -> Result, Status> { - let sr = request.into_inner().request.unwrap(); - - // tx,rx pair for sending data over to user-defined source - let (stx, mut srx) = mpsc::channel::(sr.num_records as usize); - // tx,rx pair for gRPC response - let (tx, rx) = - mpsc::channel::>(sr.num_records as usize); - - // start the ud-source rx asynchronously and start populating the gRPC response, so it can be streamed to the gRPC client (numaflow). 
- tokio::spawn(async move { - while let Some(resp) = srx.recv().await { - tx.send(Ok(proto::ReadResponse { - result: Some(proto::read_response::Result { - payload: resp.value, - offset: Some(proto::Offset { - offset: resp.offset.offset, - partition_id: resp.offset.partition_id, - }), - event_time: prost_timestamp_from_utc(resp.event_time), - keys: resp.keys, - }), - })) - .await - .expect("receiver dropped"); - } - }); - - let handler_fn = Arc::clone(&self.handler); - // we want to start streaming to the server as soon as possible - tokio::spawn(async move { - // user-defined source read handler - handler_fn - .read( - SourceReadRequest { - count: sr.num_records as usize, - timeout: Duration::from_millis(sr.timeout_in_ms as u64), - }, - stx, - ) - .await - }); - - Ok(Response::new(ReceiverStream::new(rx))) + async fn read_fn(&self, request: Request>) -> Result, Status> { + todo!() } - async fn ack_fn( - &self, - request: Request, - ) -> Result, Status> { - let ar: proto::AckRequest = request.into_inner(); - - let success_response = Response::new(proto::AckResponse { - result: Some(proto::ack_response::Result { success: Some(()) }), - }); - - let Some(request) = ar.request else { - return Ok(success_response); - }; - - // invoke the user-defined source's ack handler - let offsets = request - .offsets - .into_iter() - .map(|so| Offset { - offset: so.offset, - partition_id: so.partition_id, - }) - .collect(); - - self.handler.ack(offsets).await; - - Ok(success_response) + async fn ack_fn(&self, request: Request>) -> Result, Status> { + todo!() } + async fn pending_fn(&self, _: Request<()>) -> Result, Status> { // invoke the user-defined source's pending handler let pending = self.handler.pending().await; From 2e9496e6bff2be4712904fc3aa4017ffa403f2a1 Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Thu, 12 Sep 2024 10:36:41 +0530 Subject: [PATCH 2/7] source bidirectional streaming Signed-off-by: Yashash H L --- src/source.rs | 119 
++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 111 insertions(+), 8 deletions(-) diff --git a/src/source.rs b/src/source.rs index 15dae3b..c4ce627 100644 --- a/src/source.rs +++ b/src/source.rs @@ -4,18 +4,23 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; +use crate::error::Error::SourceError; +use crate::error::{Error, ErrorKind}; use crate::shared::{self, prost_timestamp_from_utc}; +use crate::source::proto::{AckRequest, AckResponse, ReadRequest}; use chrono::{DateTime, Utc}; use tokio::sync::mpsc::{self, Sender}; use tokio::sync::oneshot; +use tokio::task::JoinHandle; use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; use tonic::{async_trait, Request, Response, Status, Streaming}; -use crate::source::proto::{AckRequest, AckResponse, ReadRequest}; const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/source.sock"; const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcer-server-info"; +// TODO: use batch-size, blocked by https://github.com/numaproj/numaflow/issues/2026 +const DEFAULT_CHANNEL_SIZE: usize = 1000; /// Source Proto definitions. pub mod proto { @@ -48,8 +53,8 @@ struct SourceService { pub trait Sourcer { /// Reads the messages from the source and sends them to the transmitter. async fn read(&self, request: SourceReadRequest, transmitter: Sender); - /// Acknowledges the messages that have been processed by the user-defined source. - async fn ack(&self, offsets: Vec); + /// Acknowledges the message that has been processed by the user-defined source. + async fn ack(&self, offset: Offset); /// Returns the number of messages that are yet to be processed by the user-defined source. async fn pending(&self) -> usize; /// Returns the partitions associated with the source. 
This will be used by the platform to determine @@ -82,14 +87,112 @@ where { type ReadFnStream = ReceiverStream>; - async fn read_fn(&self, request: Request>) -> Result, Status> { - todo!() - } + async fn read_fn( + &self, + request: Request>, + ) -> Result, Status> { + let mut sr = request.into_inner(); + + // tx,rx pair for gRPC response + let (tx, rx) = mpsc::channel::>(DEFAULT_CHANNEL_SIZE); + + let handler_fn = Arc::clone(&self.handler); + + let grpc_read_handle: JoinHandle> = tokio::spawn(async move { + while let Some(read_request) = sr + .message() + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? + { + // tx,rx pair for sending data over to user-defined source + let (stx, mut srx) = mpsc::channel::(DEFAULT_CHANNEL_SIZE); + + let Some(request) = read_request.request else { + panic!("request cannot be empty"); + }; + + let grpc_resp_tx = tx.clone(); + // start the ud-source rx asynchronously and start populating the gRPC response, so it can be streamed to the gRPC client (numaflow). 
+ let grpc_writer_handle: JoinHandle> = tokio::spawn(async move { + while let Some(resp) = srx.recv().await { + grpc_resp_tx + .send(Ok(proto::ReadResponse { + result: Some(proto::read_response::Result { + payload: resp.value, + offset: Some(proto::Offset { + offset: resp.offset.offset, + partition_id: resp.offset.partition_id, + }), + event_time: prost_timestamp_from_utc(resp.event_time), + keys: resp.keys, + headers: Default::default(), + }), + status: None, + })) + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; + } + + grpc_resp_tx + .send(Ok(proto::ReadResponse { + result: None, + status: Some(proto::read_response::Status { + eot: true, + code: 0, + error: 0, + msg: None, + }), + })) + .await + .map_err(|e| Error::SourceError(ErrorKind::InternalError(e.to_string())))?; + + Ok(()) + }); + + handler_fn + .read( + SourceReadRequest { + count: request.num_records as usize, + timeout: Duration::from_millis(request.timeout_in_ms as u64), + }, + stx, + ) + .await; - async fn ack_fn(&self, request: Request>) -> Result, Status> { - todo!() + grpc_writer_handle + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; + } + Ok(()) + }); + + // we want to start streaming to the server as soon as possible + tokio::spawn(async move { + // user-defined source read handler + }); + + Ok(Response::new(ReceiverStream::new(rx))) } + async fn ack_fn( + &self, + request: Request>, + ) -> Result, Status> { + let mut acks = request.into_inner(); + while let Some(ack_request) = acks.message().await? 
{ + let offset = ack_request.request.unwrap().offset; + self.handler + .ack(Offset { + offset: offset.clone().unwrap().offset, + partition_id: offset.unwrap().partition_id, + }) + .await; + } + Ok(Response::new(AckResponse { + result: Some(proto::ack_response::Result { success: None }), + })) + } async fn pending_fn(&self, _: Request<()>) -> Result, Status> { // invoke the user-defined source's pending handler From dcbb26834153b84853d9757e25395d92a1314d4a Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Thu, 12 Sep 2024 17:55:57 +0530 Subject: [PATCH 3/7] bidirectional source Signed-off-by: Yashash H L --- examples/simple-source/src/main.rs | 8 +- src/source.rs | 324 +++++++++++++++-------------- 2 files changed, 172 insertions(+), 160 deletions(-) diff --git a/examples/simple-source/src/main.rs b/examples/simple-source/src/main.rs index e1328bd..9127211 100644 --- a/examples/simple-source/src/main.rs +++ b/examples/simple-source/src/main.rs @@ -60,11 +60,9 @@ pub(crate) mod simple_source { self.yet_to_ack.write().unwrap().extend(message_offsets) } - async fn ack(&self, offsets: Vec) { - for offset in offsets { - let x = &String::from_utf8(offset.offset).unwrap(); - self.yet_to_ack.write().unwrap().remove(x); - } + async fn ack(&self, offset: Offset) { + let x = &String::from_utf8(offset.offset).unwrap(); + self.yet_to_ack.write().unwrap().remove(x); } async fn pending(&self) -> usize { diff --git a/src/source.rs b/src/source.rs index c4ce627..9196c37 100644 --- a/src/source.rs +++ b/src/source.rs @@ -29,8 +29,8 @@ pub mod proto { struct SourceService { handler: Arc, - _shutdown_tx: Sender<()>, - _cancellation_token: CancellationToken, + shutdown_tx: Sender<()>, + cancellation_token: CancellationToken, } // FIXME: remove async_trait @@ -98,78 +98,99 @@ where let handler_fn = Arc::clone(&self.handler); + let grpc_tx = tx.clone(); + let cln_token = self.cancellation_token.clone(); let grpc_read_handle: JoinHandle> = tokio::spawn(async move { - while let 
Some(read_request) = sr - .message() - .await - .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? - { - // tx,rx pair for sending data over to user-defined source - let (stx, mut srx) = mpsc::channel::(DEFAULT_CHANNEL_SIZE); - - let Some(request) = read_request.request else { - panic!("request cannot be empty"); - }; - - let grpc_resp_tx = tx.clone(); - // start the ud-source rx asynchronously and start populating the gRPC response, so it can be streamed to the gRPC client (numaflow). - let grpc_writer_handle: JoinHandle> = tokio::spawn(async move { - while let Some(resp) = srx.recv().await { - grpc_resp_tx - .send(Ok(proto::ReadResponse { - result: Some(proto::read_response::Result { - payload: resp.value, - offset: Some(proto::Offset { - offset: resp.offset.offset, - partition_id: resp.offset.partition_id, + loop { + tokio::select! { + read_request = sr.message() => { + let read_request = read_request + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? + .ok_or_else(|| SourceError(ErrorKind::InternalError("Stream closed".to_string())))?; + + // tx,rx pair for sending data over to user-defined source + let (stx, mut srx) = mpsc::channel::(DEFAULT_CHANNEL_SIZE); + + let Some(request) = read_request.request else { + panic!("request cannot be empty"); + }; + + let grpc_resp_tx = grpc_tx.clone(); + // start the ud-source rx asynchronously and start populating the gRPC response, so it can be streamed to the gRPC client (numaflow). 
+ let grpc_writer_handle: JoinHandle> = tokio::spawn(async move { + while let Some(resp) = srx.recv().await { + grpc_resp_tx + .send(Ok(proto::ReadResponse { + result: Some(proto::read_response::Result { + payload: resp.value, + offset: Some(proto::Offset { + offset: resp.offset.offset, + partition_id: resp.offset.partition_id, + }), + event_time: prost_timestamp_from_utc(resp.event_time), + keys: resp.keys, + headers: Default::default(), + }), + status: None, + })) + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; + } + + grpc_resp_tx + .send(Ok(proto::ReadResponse { + result: None, + status: Some(proto::read_response::Status { + eot: true, + code: 0, + error: 0, + msg: None, }), - event_time: prost_timestamp_from_utc(resp.event_time), - keys: resp.keys, - headers: Default::default(), - }), - status: None, - })) + })) + .await + .map_err(|e| Error::SourceError(ErrorKind::InternalError(e.to_string())))?; + + Ok(()) + }); + + handler_fn + .read( + SourceReadRequest { + count: request.num_records as usize, + timeout: Duration::from_millis(request.timeout_in_ms as u64), + }, + stx, + ) + .await; + + grpc_writer_handle .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; } - - grpc_resp_tx - .send(Ok(proto::ReadResponse { - result: None, - status: Some(proto::read_response::Status { - eot: true, - code: 0, - error: 0, - msg: None, - }), - })) - .await - .map_err(|e| Error::SourceError(ErrorKind::InternalError(e.to_string())))?; - - Ok(()) - }); - - handler_fn - .read( - SourceReadRequest { - count: request.num_records as usize, - timeout: Duration::from_millis(request.timeout_in_ms as u64), - }, - stx, - ) - .await; - - grpc_writer_handle - .await - .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? 
- .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; + _ = cln_token.cancelled() => { + eprintln!("Cancellation token triggered, shutting down"); + break; + } + } } Ok(()) }); - // we want to start streaming to the server as soon as possible + let shutdown_tx = self.shutdown_tx.clone(); tokio::spawn(async move { - // user-defined source read handler + // wait for grpc read handle, if there are any errors write to the grpc response channel + if let Err(e) = grpc_read_handle.await { + tx.send(Err(Status::internal(e.to_string()))) + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string()))) + .expect("writing error to grpc response channel should never fail"); + + shutdown_tx + .send(()) + .await + .expect("write to shutdown channel should never fail"); + } }); Ok(Response::new(ReceiverStream::new(rx))) @@ -179,18 +200,26 @@ where &self, request: Request>, ) -> Result, Status> { - let mut acks = request.into_inner(); - while let Some(ack_request) = acks.message().await? { - let offset = ack_request.request.unwrap().offset; + let mut ack_stream = request.into_inner(); + while let Some(ack_request) = ack_stream.message().await? 
{ + // the request is not there send back status as invalid argument + let Some(request) = ack_request.request else { + return Err(Status::invalid_argument("request is empty")); + }; + + let Some(offset) = request.offset else { + return Err(Status::invalid_argument("offset is not present")); + }; + self.handler .ack(Offset { - offset: offset.clone().unwrap().offset, - partition_id: offset.unwrap().partition_id, + offset: offset.clone().offset, + partition_id: offset.partition_id, }) .await; } Ok(Response::new(AckResponse { - result: Some(proto::ack_response::Result { success: None }), + result: Some(proto::ack_response::Result { success: Some(()) }), })) } @@ -312,8 +341,8 @@ impl Server { let source_service = SourceService { handler: Arc::new(handler), - _shutdown_tx: internal_shutdown_tx, - _cancellation_token: cln_token.clone(), + shutdown_tx: internal_shutdown_tx, + cancellation_token: cln_token.clone(), }; let source_svc = proto::source_server::SourceServer::new(source_service) @@ -354,18 +383,19 @@ impl Drop for Server { #[cfg(test)] mod tests { use super::{proto, Message, Offset, SourceReadRequest}; + use crate::source; use chrono::Utc; use std::collections::{HashMap, HashSet}; + use std::error::Error; + use std::time::Duration; use std::vec; - use std::{error::Error, time::Duration}; - - use crate::source; use tempfile::TempDir; use tokio::net::UnixStream; use tokio::sync::mpsc::Sender; - use tokio::sync::oneshot; - use tokio_stream::StreamExt; + use tokio::sync::{mpsc, oneshot}; + use tokio_stream::wrappers::ReceiverStream; use tonic::transport::Uri; + use tonic::Request; use tower::service_fn; use uuid::Uuid; @@ -413,13 +443,11 @@ mod tests { self.yet_to_ack.write().unwrap().extend(message_offsets) } - async fn ack(&self, offsets: Vec) { - for offset in offsets { - self.yet_to_ack - .write() - .unwrap() - .remove(&String::from_utf8(offset.offset).unwrap()); - } + async fn ack(&self, offset: Offset) { + self.yet_to_ack + .write() + .unwrap() + 
.remove(&String::from_utf8(offset.offset).unwrap()); } async fn pending(&self) -> usize { @@ -469,83 +497,69 @@ mod tests { .await?; let mut client = proto::source_client::SourceClient::new(channel); - let request = tonic::Request::new(proto::ReadRequest { + + // Test read_fn with bidirectional streaming + let (read_tx, read_rx) = mpsc::channel(4); + let read_request = proto::ReadRequest { request: Some(proto::read_request::Request { num_records: 5, - timeout_in_ms: 500, + timeout_in_ms: 1000, }), - }); + }; + read_tx.send(read_request).await.unwrap(); + drop(read_tx); // Close the sender to indicate no more requests - let resp = client.read_fn(request).await?; - let resp = resp.into_inner(); - let result: Vec = resp - .map(|item| item.unwrap().result.unwrap()) - .collect() - .await; - let response_values: Vec = result - .iter() - .map(|item| { - usize::from_le_bytes( - item.payload - .clone() - .try_into() - .expect("expected Vec length to be 8"), - ) - }) - .collect(); - assert_eq!(response_values, vec![8, 8, 8, 8, 8]); - - let pending_before_ack = client - .pending_fn(tonic::Request::new(())) - .await - .unwrap() - .into_inner(); - assert_eq!( - pending_before_ack.result.unwrap().count, - 5, - "Expected pending messages to be 5 before ACK" - ); - - let offsets_to_ack: Vec = result - .iter() - .map(|item| item.clone().offset.unwrap()) - .collect(); - let ack_request = tonic::Request::new(proto::AckRequest { - request: Some(proto::ack_request::Request { - offsets: offsets_to_ack, - }), - }); - let resp = client.ack_fn(ack_request).await.unwrap().into_inner(); - assert!( - resp.result.unwrap().success.is_some(), - "Expected acknowledgement request to be successful" - ); - - let pending_before_ack = client - .pending_fn(tonic::Request::new(())) - .await - .unwrap() + let mut response_stream = client + .read_fn(Request::new(ReceiverStream::new(read_rx))) + .await? 
.into_inner(); - assert_eq!( - pending_before_ack.result.unwrap().count, - 0, - "Expected pending messages to be 0 after ACK" - ); - - let partitions = client - .partitions_fn(tonic::Request::new(())) - .await - .unwrap() + let mut response_values = Vec::new(); + + while let Some(response) = response_stream.message().await? { + if let Some(status) = response.status { + if status.eot { + break; + } + } + + if let Some(result) = response.result { + response_values.push(result); + } + } + assert_eq!(response_values.len(), 5); + + // Test pending_fn + let pending_before_ack = client.pending_fn(Request::new(())).await?.into_inner(); + assert_eq!(pending_before_ack.result.unwrap().count, 5); + + // Test ack_fn with client-side streaming + let (ack_tx, ack_rx) = mpsc::channel(10); + for resp in response_values.iter() { + let ack_request = proto::AckRequest { + request: Some(proto::ack_request::Request { + offset: Some(proto::Offset { + offset: resp.offset.clone().unwrap().offset, + partition_id: resp.offset.clone().unwrap().partition_id, + }), + }), + }; + ack_tx.send(ack_request).await.unwrap(); + } + drop(ack_tx); // Close the sender to indicate no more requests + + let ack_response = client + .ack_fn(Request::new(ReceiverStream::new(ack_rx))) + .await? 
.into_inner(); - assert_eq!( - partitions.result.unwrap().partitions, - vec![2], - "Expected number of partitions to be 2" - ); - - shutdown_tx - .send(()) - .expect("Sending shutdown signal to gRPC server"); + assert!(ack_response.result.unwrap().success.is_some()); + + let pending_after_ack = client.pending_fn(Request::new(())).await?.into_inner(); + assert_eq!(pending_after_ack.result.unwrap().count, 0); + + let partitions = client.partitions_fn(Request::new(())).await?.into_inner(); + assert_eq!(partitions.result.unwrap().partitions, vec![2]); + + shutdown_tx.send(()).unwrap(); tokio::time::sleep(Duration::from_millis(50)).await; assert!(task.is_finished(), "gRPC server is still running"); Ok(()) From c117aa967f50d338202cc3f7e0d22e2df5f3aa96 Mon Sep 17 00:00:00 2001 From: Vigith Maurice Date: Fri, 13 Sep 2024 08:10:17 -0700 Subject: [PATCH 4/7] chore: some comments Signed-off-by: Vigith Maurice --- src/source.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/source.rs b/src/source.rs index 9196c37..8e937f0 100644 --- a/src/source.rs +++ b/src/source.rs @@ -93,12 +93,14 @@ where ) -> Result, Status> { let mut sr = request.into_inner(); - // tx,rx pair for gRPC response + // tx (read from client) ,rx (write to client) pair for gRPC response let (tx, rx) = mpsc::channel::>(DEFAULT_CHANNEL_SIZE); let handler_fn = Arc::clone(&self.handler); + // this _tx ends up writing to the client side let grpc_tx = tx.clone(); + let cln_token = self.cancellation_token.clone(); let grpc_read_handle: JoinHandle> = tokio::spawn(async move { loop { @@ -115,9 +117,13 @@ where panic!("request cannot be empty"); }; + // start the ud-source rx asynchronously and start populating the gRPC + // response, so it can be streamed to the gRPC client (numaflow). let grpc_resp_tx = grpc_tx.clone(); - // start the ud-source rx asynchronously and start populating the gRPC response, so it can be streamed to the gRPC client (numaflow). 
let grpc_writer_handle: JoinHandle> = tokio::spawn(async move { + // even though we use bi-di; the user-defined source sees this as a 1/2 duplex + // server side streaming. this means that the below while loop will terminate + // after every batch of read has been returned. while let Some(resp) = srx.recv().await { grpc_resp_tx .send(Ok(proto::ReadResponse { @@ -137,6 +143,7 @@ where .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; } + // send end of transmission on success grpc_resp_tx .send(Ok(proto::ReadResponse { result: None, @@ -148,7 +155,7 @@ where }), })) .await - .map_err(|e| Error::SourceError(ErrorKind::InternalError(e.to_string())))?; + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; Ok(()) }); From 205517af31ba4dc9ed674eff4fbc2bd38dda2a81 Mon Sep 17 00:00:00 2001 From: Vigith Maurice Date: Fri, 13 Sep 2024 08:55:36 -0700 Subject: [PATCH 5/7] chorre: fix code readability Signed-off-by: Vigith Maurice --- src/source.rs | 148 +++++++++++++++++++++++++++++--------------------- 1 file changed, 86 insertions(+), 62 deletions(-) diff --git a/src/source.rs b/src/source.rs index 8e937f0..60479c7 100644 --- a/src/source.rs +++ b/src/source.rs @@ -7,14 +7,15 @@ use std::time::Duration; use crate::error::Error::SourceError; use crate::error::{Error, ErrorKind}; use crate::shared::{self, prost_timestamp_from_utc}; -use crate::source::proto::{AckRequest, AckResponse, ReadRequest}; +use crate::source::proto::{AckRequest, AckResponse, ReadRequest, ReadResponse}; use chrono::{DateTime, Utc}; -use tokio::sync::mpsc::{self, Sender}; +use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::sync::oneshot; use tokio::task::JoinHandle; use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; use tonic::{async_trait, Request, Response, Status, Streaming}; +use tracing::{error, info}; const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = 
"/var/run/numaflow/source.sock"; @@ -80,12 +81,88 @@ pub struct Offset { pub partition_id: i32, } +impl SourceService +where + T: Sourcer + Send + Sync + 'static, +{ + async fn write_a_batch( + grpc_resp_tx: Sender>, + mut srx: Receiver, + ) -> crate::error::Result<()> { + // even though we use bi-di; the user-defined source sees this as a 1/2 duplex + // server side streaming. this means that the below while loop will terminate + // after every batch of read has been returned. + while let Some(resp) = srx.recv().await { + grpc_resp_tx + .send(Ok(ReadResponse { + result: Some(proto::read_response::Result { + payload: resp.value, + offset: Some(proto::Offset { + offset: resp.offset.offset, + partition_id: resp.offset.partition_id, + }), + event_time: prost_timestamp_from_utc(resp.event_time), + keys: resp.keys, + headers: Default::default(), + }), + status: None, + })) + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; + } + + // send end of transmission on success + grpc_resp_tx + .send(Ok(ReadResponse { + result: None, + status: Some(proto::read_response::Status { + eot: true, + code: 0, + error: 0, + msg: None, + }), + })) + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; + + Ok(()) + } + + async fn forward_a_batch( + handler_fn: Arc, + grpc_resp_tx: Sender>, + stx: Sender, + srx: Receiver, + request: proto::read_request::Request, + ) -> crate::error::Result<()> { + let grpc_writer_handle: JoinHandle> = + tokio::spawn(async move { Self::write_a_batch(grpc_resp_tx, srx).await }); + + handler_fn + .read( + SourceReadRequest { + count: request.num_records as usize, + timeout: Duration::from_millis(request.timeout_in_ms as u64), + }, + stx, + ) + .await; + + grpc_writer_handle + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? 
+ .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; + + Ok(()) + } +} + #[async_trait] impl proto::source_server::Source for SourceService where T: Sourcer + Send + Sync + 'static, { - type ReadFnStream = ReceiverStream>; + type ReadFnStream = ReceiverStream>; async fn read_fn( &self, @@ -111,72 +188,18 @@ where .ok_or_else(|| SourceError(ErrorKind::InternalError("Stream closed".to_string())))?; // tx,rx pair for sending data over to user-defined source - let (stx, mut srx) = mpsc::channel::(DEFAULT_CHANNEL_SIZE); + let (stx, srx) = mpsc::channel::(DEFAULT_CHANNEL_SIZE); - let Some(request) = read_request.request else { - panic!("request cannot be empty"); - }; + let request = read_request.request.ok_or_else(|| SourceError(ErrorKind::InternalError("Stream closed".to_string())))?; // start the ud-source rx asynchronously and start populating the gRPC // response, so it can be streamed to the gRPC client (numaflow). let grpc_resp_tx = grpc_tx.clone(); - let grpc_writer_handle: JoinHandle> = tokio::spawn(async move { - // even though we use bi-di; the user-defined source sees this as a 1/2 duplex - // server side streaming. this means that the below while loop will terminate - // after every batch of read has been returned. 
- while let Some(resp) = srx.recv().await { - grpc_resp_tx - .send(Ok(proto::ReadResponse { - result: Some(proto::read_response::Result { - payload: resp.value, - offset: Some(proto::Offset { - offset: resp.offset.offset, - partition_id: resp.offset.partition_id, - }), - event_time: prost_timestamp_from_utc(resp.event_time), - keys: resp.keys, - headers: Default::default(), - }), - status: None, - })) - .await - .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; - } - - // send end of transmission on success - grpc_resp_tx - .send(Ok(proto::ReadResponse { - result: None, - status: Some(proto::read_response::Status { - eot: true, - code: 0, - error: 0, - msg: None, - }), - })) - .await - .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; - - Ok(()) - }); - - handler_fn - .read( - SourceReadRequest { - count: request.num_records as usize, - timeout: Duration::from_millis(request.timeout_in_ms as u64), - }, - stx, - ) - .await; - - grpc_writer_handle - .await - .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? - .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; + + Self::forward_a_batch(handler_fn.clone(), grpc_resp_tx, stx, srx, request).await? 
} _ = cln_token.cancelled() => { - eprintln!("Cancellation token triggered, shutting down"); + info!("Cancellation token triggered, shutting down"); break; } } @@ -188,6 +211,7 @@ where tokio::spawn(async move { // wait for grpc read handle, if there are any errors write to the grpc response channel if let Err(e) = grpc_read_handle.await { + error!("shutting down the gRPC channel, {}", e); tx.send(Err(Status::internal(e.to_string()))) .await .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string()))) From e42cce5b6df67d84bcf63095738d9abbeb4ca395 Mon Sep 17 00:00:00 2001 From: Vigith Maurice Date: Fri, 13 Sep 2024 09:55:44 -0700 Subject: [PATCH 6/7] chore: add crate::error::Result Signed-off-by: Vigith Maurice --- src/error.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/error.rs b/src/error.rs index e33102f..5a3818a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,7 @@ use thiserror::Error; +pub type Result = std::result::Result; + #[derive(Error, Debug, Clone)] pub enum ErrorKind { #[error("User Defined Error: {0}")] From 709e9ee115a3cc212ed612559f0ac92bd866c428 Mon Sep 17 00:00:00 2001 From: Vigith Maurice Date: Fri, 13 Sep 2024 17:33:55 -0700 Subject: [PATCH 7/7] doc: some comments Signed-off-by: Vigith Maurice --- src/source.rs | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/src/source.rs b/src/source.rs index 60479c7..94cc841 100644 --- a/src/source.rs +++ b/src/source.rs @@ -85,14 +85,15 @@ impl SourceService where T: Sourcer + Send + Sync + 'static, { + /// writes a read batch returned by the user-defined handler to the client (numaflow). async fn write_a_batch( grpc_resp_tx: Sender>, - mut srx: Receiver, + mut udsource_rx: Receiver, ) -> crate::error::Result<()> { // even though we use bi-di; the user-defined source sees this as a 1/2 duplex // server side streaming. this means that the below while loop will terminate // after every batch of read has been returned. 
- while let Some(resp) = srx.recv().await { + while let Some(resp) = udsource_rx.recv().await { grpc_resp_tx .send(Ok(ReadResponse { result: Some(proto::read_response::Result { @@ -128,16 +129,23 @@ where Ok(()) } + /// Invokes the user-defined source handler to get a read batch and streams it to the numaflow + /// (client). async fn forward_a_batch( handler_fn: Arc, grpc_resp_tx: Sender>, - stx: Sender, - srx: Receiver, request: proto::read_request::Request, ) -> crate::error::Result<()> { + // tx,rx pair for sending data over to user-defined source + let (stx, srx) = mpsc::channel::(DEFAULT_CHANNEL_SIZE); + + // spawn the rx side so that when the handler is invoked, we can stream the handler's read data + // to the gRPC response stream. let grpc_writer_handle: JoinHandle> = tokio::spawn(async move { Self::write_a_batch(grpc_resp_tx, srx).await }); + // invoke the handler; it will stream the data to the tx passed, which will be streamed to the client + // by the above task. handler_fn .read( SourceReadRequest { @@ -148,6 +156,7 @@ where ) .await; + // wait for the spawned grpc writer to end grpc_writer_handle .await .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? @@ -169,34 +178,37 @@ where request: Request>, ) -> Result, Status> { let mut sr = request.into_inner(); - - // tx (read from client) ,rx (write to client) pair for gRPC response - let (tx, rx) = mpsc::channel::>(DEFAULT_CHANNEL_SIZE); - + // we have to call the handler over and over for each ReadRequest let handler_fn = Arc::clone(&self.handler); + // tx (read from client), rx (write to client) pair for gRPC response + let (tx, rx) = mpsc::channel::>(DEFAULT_CHANNEL_SIZE); + // this _tx ends up writing to the client side let grpc_tx = tx.clone(); let cln_token = self.cancellation_token.clone(); + + // this is the top-level stream consumer and this task will only exit when stream is closed (which + // will happen when server and client are shutting down).
let grpc_read_handle: JoinHandle> = tokio::spawn(async move { loop { tokio::select! { + // for each ReadRequest message, the handler will be called and a batch of messages + // will be sent over to the client. read_request = sr.message() => { let read_request = read_request .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? .ok_or_else(|| SourceError(ErrorKind::InternalError("Stream closed".to_string())))?; - // tx,rx pair for sending data over to user-defined source - let (stx, srx) = mpsc::channel::(DEFAULT_CHANNEL_SIZE); - let request = read_request.request.ok_or_else(|| SourceError(ErrorKind::InternalError("Stream closed".to_string())))?; // start the ud-source rx asynchronously and start populating the gRPC // response, so it can be streamed to the gRPC client (numaflow). let grpc_resp_tx = grpc_tx.clone(); - Self::forward_a_batch(handler_fn.clone(), grpc_resp_tx, stx, srx, request).await? + // let's forward a batch for this request + Self::forward_a_batch(handler_fn.clone(), grpc_resp_tx, request).await? } _ = cln_token.cancelled() => { info!("Cancellation token triggered, shutting down"); @@ -208,8 +220,10 @@ where }); let shutdown_tx = self.shutdown_tx.clone(); + // spawn so we can return the recv stream to client. tokio::spawn(async move { - // wait for grpc read handle, if there are any errors write to the grpc response channel + // wait for the grpc read handle; if there are any errors, we set the gRPC Status to failure + // which will close the stream with failure. if let Err(e) = grpc_read_handle.await { error!("shutting down the gRPC channel, {}", e); tx.send(Err(Status::internal(e.to_string()))) @@ -217,6 +231,7 @@ where .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string()))) .expect("writing error to grpc response channel should never fail"); + // if there are any failures, we propagate those failures so that the server can shutdown. shutdown_tx .send(()) .await