diff --git a/akka-projection-rs-grpc/src/consumer.rs b/akka-projection-rs-grpc/src/consumer.rs index 3d6a510..66015eb 100644 --- a/akka-projection-rs-grpc/src/consumer.rs +++ b/akka-projection-rs-grpc/src/consumer.rs @@ -18,6 +18,7 @@ use prost_types::Timestamp; use std::{future::Future, marker::PhantomData, ops::Range, pin::Pin}; use tokio::sync::mpsc; use tokio::sync::oneshot; +use tokio::sync::watch; use tokio_stream::Stream; use tokio_stream::StreamExt; use tonic::transport::Channel; @@ -29,9 +30,9 @@ use crate::EventEnvelope; use crate::StreamId; pub struct GrpcSourceProvider { + consumer_filters: Option>>, delayer: Option, event_producer_channel: EP, - initial_consumer_filter: Vec, offset_store: mpsc::Sender>, slice_range: Range, stream_id: StreamId, @@ -65,9 +66,9 @@ where slice_range: Range, ) -> Self { Self { + consumer_filters: None, delayer: None, event_producer_channel, - initial_consumer_filter: vec![], offset_store, slice_range, stream_id, @@ -75,14 +76,11 @@ where } } - pub fn with_initial_consumer_filter( + pub fn with_consumer_filters( mut self, - initial_consumer_filter: Vec, + consumer_filters: watch::Receiver>, ) -> Self { - self.initial_consumer_filter = initial_consumer_filter - .into_iter() - .map(|f| f.into()) - .collect(); + self.consumer_filters = Some(consumer_filters); self } } @@ -142,6 +140,28 @@ where } }); + let stream_consumer_filters = self.consumer_filters.as_ref().cloned(); + + let consumer_filters = stream! { + if let Some(mut consumer_filters) = stream_consumer_filters { + while consumer_filters.changed().await.is_ok() { + let criteria: Vec = consumer_filters + .borrow() + .clone() + .into_iter() + .map(|c| c.into()) + .collect(); + yield proto::StreamIn { + message: Some(proto::stream_in::Message::Filter(proto::FilterReq { + criteria, + })), + }; + } + } else { + futures::future::pending::<()>().await; + } + }; + let request = Request::new( tokio_stream::iter(vec![proto::StreamIn { message: Some(proto::stream_in::Message::Init(proto::InitReq { @@ -149,10 +169,20 @@ where slice_min: self.slice_range.start as i32, slice_max: self.slice_range.end as i32 - 1, offset, - filter: self.initial_consumer_filter.clone(), + filter: self + .consumer_filters + .as_ref() + .map_or(vec![], |consumer_filters| { + consumer_filters + .borrow() + .clone() + .into_iter() + .map(|c| c.into()) + .collect() + }), })), }]) - .chain(tokio_stream::pending()), + .chain(consumer_filters), ); let result = connection.events_by_slices(request).await; @@ -267,7 +297,7 @@ mod tests { use super::*; use akka_persistence_rs::{EntityId, EntityType, PersistenceId}; - use akka_projection_rs::consumer_filter::EntityIdOffset; + use akka_projection_rs::consumer_filter::{self, EntityIdOffset}; use async_stream::stream; use chrono::{DateTime, Utc}; use prost_types::Any; @@ -289,8 +319,36 @@ mod tests { async fn events_by_slices( &self, - _request: Request>, + request: Request>, ) -> Result, Status> { + let mut inner = request.into_inner(); + + let Some(Ok(proto::StreamIn { + message: Some(proto::stream_in::Message::Init(proto::InitReq { filter, .. })), + })) = inner.next().await + else { + return Err(Status::aborted("Expected the initial request")); + }; + + if filter.is_empty() { + return Err(Status::aborted( + "Expected the initial request to have a filter", + )); + } + + let Some(Ok(proto::StreamIn { + message: Some(proto::stream_in::Message::Filter(proto::FilterReq { criteria, .. })), + })) = inner.next().await + else { + return Err(Status::aborted("Expected the criteria to be updated")); + }; + + if criteria.is_empty() { + return Err(Status::aborted( + "Expected the filter request to have a filter", + )); + } + let stream_event_time = self.event_time; let stream_event_seen_by = self.event_seen_by.clone(); Ok(Response::new(Box::pin(stream!({ @@ -432,18 +490,25 @@ mod tests { } }); + let (consumer_filters, consumer_filters_receiver) = + watch::channel(vec![FilterCriteria::IncludeEntityIds { + entity_id_offsets: vec![EntityIdOffset { + entity_id: entity_id.clone(), + seq_nr: 0, + }], + }]); + let channel = Channel::from_static("http://127.0.0.1:50051"); let mut source_provider = GrpcSourceProvider::::new( || channel.connect(), StreamId::from("some-string-id"), offset_store, ) - .with_initial_consumer_filter(vec![FilterCriteria::IncludeEntityIds { - entity_id_offsets: vec![EntityIdOffset { - entity_id: entity_id.clone(), - seq_nr: 0, - }], - }]); + .with_consumer_filters(consumer_filters_receiver); + + assert!(consumer_filters + .send(vec![consumer_filter::exclude_all()]) + .is_ok()); let mut tried = 0; diff --git a/akka-projection-rs/src/consumer_filter.rs b/akka-projection-rs/src/consumer_filter.rs index 6b7a47e..c6214bc 100644 --- a/akka-projection-rs/src/consumer_filter.rs +++ b/akka-projection-rs/src/consumer_filter.rs @@ -28,6 +28,7 @@ use akka_persistence_rs::EntityId; use smol_str::SmolStr; +#[derive(Clone)] pub struct EntityIdOffset { pub entity_id: EntityId, // If this is defined (> 0) events are replayed from the given @@ -48,6 +49,7 @@ pub type TopicMatcher = SmolStr; /// If an exclude criteria is matching the include criteria are evaluated. /// If no matching include criteria the event is discarded. /// If matching include criteria the event is emitted. +#[derive(Clone)] pub enum FilterCriteria { /// Exclude events with any of the given tags, unless there is a /// matching include filter that overrides the exclude.