Skip to content

Commit

Permalink
Unit tests for the gRPC producer
Browse files Browse the repository at this point in the history
Along with having verified that it works with its JVM counterpart. Simplified the re-issuing of in-flight requests also.
  • Loading branch information
huntc committed Sep 10, 2023
1 parent f226306 commit 34e64c0
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 102 deletions.
1 change: 1 addition & 0 deletions akka-projection-rs-grpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ bytes = { workspace = true }
chrono = { workspace = true }
exponential-backoff = { workspace = true }
futures = { workspace = true }
log = { workspace = true }
prost = { workspace = true }
prost-types = { workspace = true }
smol_str = { workspace = true }
Expand Down
223 changes: 121 additions & 102 deletions akka-projection-rs-grpc/src/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use akka_projection_rs::HandlerError;
use akka_projection_rs::PendingHandler;
use async_stream::stream;
use async_trait::async_trait;
use log::debug;
use log::warn;
use prost::Name;
use prost_types::Any;
use std::collections::HashMap;
Expand Down Expand Up @@ -130,6 +132,8 @@ impl<E> GrpcEventProducer<E> {
}
}

type Context<E> = (EventEnvelope<E>, oneshot::Sender<()>);

/// Reliably stream event envelopes to a consumer. Event envelope transmission
/// requests are sent over a channel and have a reply that is completed on the
/// remote consumer's acknowledgement of receipt.
Expand All @@ -145,9 +149,7 @@ pub async fn run<E>(
let mut delayer: Option<Delayer> = None;
let mut connection = None;

let mut in_flight = HashMap::new();

let mut unused_request: Option<(EventEnvelope<E>, oneshot::Sender<()>)> = None;
let mut in_flight: HashMap<PersistenceId, VecDeque<Context<E>>> = HashMap::new();

'outer: loop {
if let Err(oneshot::error::TryRecvError::Closed) = kill_switch.try_recv() {
Expand Down Expand Up @@ -236,87 +238,74 @@ pub async fn run<E>(
if let Ok(response) = result {
let mut stream_outs = response.into_inner();

fn push_in_flight<'a, E>(
in_flight: &'a mut HashMap<PersistenceId, VecDeque<(u64, oneshot::Sender<()>)>>,
envelope: &EventEnvelope<E>,
reply_to: oneshot::Sender<()>,
) -> &'a mut VecDeque<(u64, oneshot::Sender<()>)> {
let contexts = in_flight
.entry(envelope.persistence_id.clone())
.or_default();
contexts.push_back((envelope.seq_nr, reply_to));
contexts
}

if let Some((envelope, reply_to)) = unused_request.take() {
let (inner_reply_to, inner_reply) = oneshot::channel::<()>();
loop {
tokio::select! {
Some(Ok(proto::ConsumeEventOut {
message: Some(message),
})) = stream_outs.next() => match message {
proto::consume_event_out::Message::Start(proto::ConsumerEventStart { .. }) => {
debug!("Starting the protocol");
break;
}
_ => {
warn!("Received a message before starting the protocol - ignoring event");
}
},

let contexts = push_in_flight(&mut in_flight, &envelope, inner_reply_to);
_ = &mut kill_switch => break 'outer,
}
}

if event_in.send(envelope.clone()).is_ok() {
if inner_reply.await.is_ok() {
if reply_to.send(()).is_err() {
break 'outer;
}
} else {
unused_request = Some((envelope, reply_to));
// We could do an extra check here and remove the hashmap entry,
// but I don't think it is worth it as it will get cleared up
// given the next round of successful request/responses.
contexts.pop_back();
continue;
for (_, contexts) in in_flight.iter() {
for (envelope, _) in contexts {
if event_in.send(envelope.clone()).is_err() {
break 'outer;
}
} else {
unused_request = Some((envelope, reply_to));
// See the comment above.
contexts.pop_back();
continue;
}
}

loop {
tokio::select! {
request = envelopes.recv() => {
if let Some((envelope, reply_to)) = request {
let (inner_reply_to, inner_reply) = oneshot::channel();

let contexts = push_in_flight(&mut in_flight, &envelope, inner_reply_to);
let contexts = in_flight
.entry(envelope.persistence_id.clone())
.or_default();
contexts.push_back((envelope.clone(), reply_to));

if event_in.send(envelope.clone()).is_ok() && inner_reply.await.is_ok() {
if reply_to.send(()).is_ok() {
continue;
} else {
break 'outer;
}
if event_in.send(envelope).is_err() {
break 'outer;
}

unused_request = Some((envelope, reply_to));
contexts.pop_back();

break;

} else {
break 'outer;
}
}

Some(Ok(proto::ConsumeEventOut {
message: Some(proto::consume_event_out::Message::Ack(proto::ConsumerEventAck { persistence_id, seq_nr })),
})) = stream_outs.next() => {
if let Ok(persistence_id) = persistence_id.parse() {
if let Some(contexts) = in_flight.get_mut(&persistence_id) {
let seq_nr = seq_nr as u64;
while let Some((expected_seq_nr, reply_to)) = contexts.pop_front() {
if seq_nr == expected_seq_nr && reply_to.send(()).is_ok() {
break;
message: Some(message),
})) = stream_outs.next() => match message {
proto::consume_event_out::Message::Start(proto::ConsumerEventStart { .. }) => {
warn!("Received a protocol start when already started - ignoring");
}
proto::consume_event_out::Message::Ack(proto::ConsumerEventAck { persistence_id, seq_nr }) => {
if let Ok(persistence_id) = persistence_id.parse() {
if let Some(contexts) = in_flight.get_mut(&persistence_id) {
let seq_nr = seq_nr as u64;
while let Some((envelope, reply_to)) = contexts.pop_front() {
if seq_nr == envelope.seq_nr && reply_to.send(()).is_ok() {
break;
}
}
if contexts.is_empty() {
in_flight.remove(&persistence_id);
}
}
if contexts.is_empty() {
in_flight.remove(&persistence_id);
}
} else {
warn!("Received an event but could not parse the persistence id - ignoring event");
}
}
}
},

_ = &mut kill_switch => break 'outer,

Expand Down Expand Up @@ -349,15 +338,60 @@ mod tests {

async fn consume_event(
&self,
_request: Request<Streaming<proto::ConsumeEventIn>>,
request: Request<Streaming<proto::ConsumeEventIn>>,
) -> std::result::Result<tonic::Response<Self::ConsumeEventStream>, tonic::Status> {
todo!()
let mut consume_events_in = request.into_inner();
if let Some(Ok(proto::ConsumeEventIn {
message:
Some(proto::consume_event_in::Message::Init(proto::ConsumerEventInit {
origin_id,
stream_id,
})),
})) = consume_events_in.next().await
{
if origin_id == "some-origin-id" && stream_id == "some-stream-id" {
let consume_events_out = Box::pin(stream! {
yield Ok(proto::ConsumeEventOut {
message: Some(proto::consume_event_out::Message::Start(
proto::ConsumerEventStart {
filter: vec![]
},
)),
});

if let Some(Ok(proto::ConsumeEventIn {
message:
Some(proto::consume_event_in::Message::Event(proto::Event {
persistence_id,
seq_nr,
..
})),
})) = consume_events_in.next().await
{
yield Ok(proto::ConsumeEventOut {
message: Some(proto::consume_event_out::Message::Ack(
proto::ConsumerEventAck {
persistence_id,
seq_nr
},
)),
})
}
});
Ok(tonic::Response::new(consume_events_out))
} else {
Err(tonic::Status::failed_precondition(
"Expecting a certain origin and stream id",
))
}
} else {
Err(tonic::Status::failed_precondition("Expecting init"))
}
}
}

#[ignore]
#[test(tokio::test)]
async fn can_flow() {
async fn can_run() {
let server_kill_switch = Arc::new(Notify::new());

let task_kill_switch = server_kill_switch.clone();
Expand All @@ -369,20 +403,18 @@ mod tests {
),
)
.serve_with_shutdown(
"127.0.0.1:50051".to_socket_addrs().unwrap().next().unwrap(),
"127.0.0.1:50052".to_socket_addrs().unwrap().next().unwrap(),
task_kill_switch.notified(),
)
.await
.unwrap();
});

let mut tried = 0;

let (sender, receiver) = mpsc::channel(10);
let (_task_kill_switch, task_kill_switch_receiver) = oneshot::channel();
tokio::spawn(async move {
let _ = run(
"http://127.0.0.1:50051".parse().unwrap(),
"http://127.0.0.1:50052".parse().unwrap(),
OriginId::from("some-origin-id"),
StreamId::from("some-stream-id"),
receiver,
Expand All @@ -391,40 +423,27 @@ mod tests {
.await;
});

loop {
let (reply, reply_receiver) = oneshot::channel();
assert!(sender
.send((
EventEnvelope {
// FIXME Flesh out these fields
persistence_id: "".parse().unwrap(),
seq_nr: 1,
event: Some(prost_types::Duration {
seconds: 0,
nanos: 0
}),
offset: TimestampOffset {
timestamp: Utc::now(),
seen: vec![]
}
},
reply,
))
.await
.is_ok());

let ack = reply_receiver.await;

tried += 1;

if ack.is_err() && tried < 100 {
continue;
}

assert!(ack.is_ok());
let (reply, reply_receiver) = oneshot::channel();
assert!(sender
.send((
EventEnvelope {
persistence_id: "entity-type|entity-id".parse().unwrap(),
seq_nr: 1,
event: Some(prost_types::Duration {
seconds: 0,
nanos: 0
}),
offset: TimestampOffset {
timestamp: Utc::now(),
seen: vec![]
}
},
reply,
))
.await
.is_ok());

break;
}
assert!(reply_receiver.await.is_ok());

server_kill_switch.notified();
}
Expand Down

0 comments on commit 34e64c0

Please sign in to comment.