Skip to content

Commit

Permalink
add makefile, cancellations
Browse files Browse the repository at this point in the history
Signed-off-by: Sidhant Kohli <sidhant_kohli@intuit.com>
  • Loading branch information
Sidhant Kohli committed Jul 23, 2024
1 parent 2c9ad0b commit 243c3a1
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 46 deletions.
17 changes: 17 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Description: Makefile for Rust projects

# perform a cargo fmt on all directories containing a Cargo.toml file
.PHONY: lint
# find all directories containing Cargo.toml files
DIRS := $(shell find . -type f -name Cargo.toml -exec dirname {} \; | sort -u)
$(info Included directories: $(DIRS))
lint:
@for dir in $(DIRS); do \
echo "Formatting code in $$dir"; \
cargo fmt --all --manifest-path "$$dir/Cargo.toml"; \
done

# run cargo test on the repository root
.PHONY: test
test:
cargo test --workspace
18 changes: 9 additions & 9 deletions examples/batchmap-cat/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ struct Cat;
impl batchmap::BatchMapper for Cat {
async fn batchmap(&self, mut input: tokio::sync::mpsc::Receiver<Datum>) -> Vec<BatchResponse> {
let mut responses: Vec<BatchResponse> = Vec::new();
while let Some(datum) = input.recv().await {
let mut response = BatchResponse::from_id(datum.id);
response.append(Message {
keys: Option::from(datum.keys),
value: datum.value,
tags: None,
});
responses.push(response);
}
while let Some(datum) = input.recv().await {
let mut response = BatchResponse::from_id(datum.id);
response.append(Message {
keys: Some(datum.keys),
value: datum.value,
tags: None,
});
responses.push(response);
}
responses
}
}
8 changes: 5 additions & 3 deletions examples/reduce-counter/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {

mod counter {
use numaflow::reduce::{Message, ReduceRequest};
use numaflow::reduce::{Reducer, Metadata};
use numaflow::reduce::{Metadata, Reducer};
use tokio::sync::mpsc::Receiver;
use tonic::async_trait;

Expand Down Expand Up @@ -44,8 +44,10 @@ mod counter {
while input.recv().await.is_some() {
counter += 1;
}
let message = Message::new(counter.to_string().into_bytes()).tags(vec![]).keys(keys.clone());
let message = Message::new(counter.to_string().into_bytes())
.tags(vec![])
.keys(keys.clone());
vec![message]
}
}
}
}
7 changes: 4 additions & 3 deletions examples/sideinput/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use numaflow::sideinput::{self, SideInputer};
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
use numaflow::sideinput::{self, SideInputer};


use tonic::async_trait;

Expand Down Expand Up @@ -37,5 +36,7 @@ impl SideInputer for SideInputHandler {

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
sideinput::Server::new(SideInputHandler::new()).start().await
sideinput::Server::new(SideInputHandler::new())
.start()
.await
}
2 changes: 1 addition & 1 deletion examples/sideinput/udf/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::path::Path;

use notify::{RecursiveMode, Result, Watcher};
use numaflow::map::{Mapper, MapRequest, Message, Server};
use numaflow::map::{MapRequest, Mapper, Message, Server};
use tokio::spawn;
use tonic::async_trait;

Expand Down
67 changes: 37 additions & 30 deletions src/batchmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub mod proto {
struct BatchMapService<T: BatchMapper> {
handler: T,
_shutdown_tx: mpsc::Sender<()>,
cancellation_token: CancellationToken,
}

/// BatchMapper trait for implementing batch mode user defined function.
Expand Down Expand Up @@ -263,6 +264,7 @@ where
channel::<Result<proto::BatchMapResponse, Status>>(1);

let shutdown_tx = self._shutdown_tx.clone();
let writer_cln_token = self.cancellation_token.clone();

// counter to keep track of the number of messages received
let counter_orig = Arc::new(AtomicUsize::new(0));
Expand All @@ -271,16 +273,26 @@ where
let counter = counter_orig.clone();
// write to the user-defined channel
tokio::spawn(async move {
while let Some(next_message) = stream
.message()
.await
.expect("expected next message from stream")
{
let datum = Datum::from(next_message);
tx.send(datum)
.await
.expect("send be successfully received!");
counter.fetch_add(1, Ordering::Relaxed);
loop {
tokio::select! {
next_message = stream.message() => {
match next_message {
Ok(Some(message)) => {
let datum = Datum::from(message);
if tx.send(datum).await.is_err() {
break;
}
counter.fetch_add(1, Ordering::Relaxed);
},
// If there's an error or the stream ends, break the loop to stop the task.
Ok(None) | Err(_) => break,
}
},
// Listen for cancellation. If triggered, break the loop to stop reading new messages.
_ = writer_cln_token.cancelled() => {
break;
}
}
}
});

Expand Down Expand Up @@ -315,11 +327,9 @@ where
}))
.await;
// if the send fails, return an error status on the streaming endpoint
if send_result.is_err() {
if let Err(e) = send_result {
grpc_response_tx
.send(Err(Status::internal(
send_result.err().unwrap().to_string(),
)))
.send(Err(Status::internal(e.to_string())))
.await
.expect("send to grpc response channel failed");
return;
Expand Down Expand Up @@ -400,22 +410,21 @@ impl<T> crate::batchmap::Server<T> {
shared::create_listener_stream(&self.sock_addr, &self.server_info_file, info)?;
let handler = self.svc.take().unwrap();

let cln_token = CancellationToken::new();

// Create a channel to send shutdown signal to the server to do graceful shutdown in case of non retryable errors.
let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);
let map_svc = crate::batchmap::BatchMapService {
handler,
_shutdown_tx: internal_shutdown_tx,
cancellation_token: cln_token.clone(),
};

let map_svc = proto::batch_map_server::BatchMapServer::new(map_svc)
.max_encoding_message_size(self.max_message_size)
.max_decoding_message_size(self.max_message_size);

let shutdown = shutdown_signal(
internal_shutdown_rx,
Some(shutdown_rx),
CancellationToken::new(),
);
let shutdown = shutdown_signal(internal_shutdown_rx, Some(shutdown_rx), cln_token);

tonic::transport::Server::builder()
.add_service(map_svc)
Expand Down Expand Up @@ -608,18 +617,16 @@ mod tests {
.await?;
let mut r = resp.into_inner();

let mut error_flag = false;
let Err(server_err) = r.message().await else {
return Err("Expected error from server".into());
};

assert_eq!(server_err.code(), tonic::Code::Internal);
assert!(server_err.message().contains(
"number of responses does not \
match the number of messages received"
));

if let Err(e) = r.message().await {
assert_eq!(e.code(), tonic::Code::Internal);
assert!(e.message().contains(
"number of responses does not \
match the number of messages received"
));
error_flag = true;
}
// Check if the error flag is set
assert!(error_flag, "Expected error from server");
tokio::time::sleep(Duration::from_millis(50)).await;
assert!(task.is_finished(), "gRPC server is still running");
Ok(())
Expand Down

0 comments on commit 243c3a1

Please sign in to comment.