diff --git a/akka-persistence-rs-commitlog/src/lib.rs b/akka-persistence-rs-commitlog/src/lib.rs index ca07951..31cb6a9 100644 --- a/akka-persistence-rs-commitlog/src/lib.rs +++ b/akka-persistence-rs-commitlog/src/lib.rs @@ -37,8 +37,7 @@ where fn secret_path(&self, entity_id: &EntityId) -> Arc; #[cfg(feature = "cbor")] - async fn record(&self, mut record: ConsumerRecord) -> Option> { - let entity_id = Self::to_entity_id(&record)?; + async fn record(&self, entity_id: EntityId, mut record: ConsumerRecord) -> Option> { streambed::decrypt_buf( self.secret_store(), &self.secret_path(&entity_id), @@ -50,7 +49,7 @@ where } #[cfg(not(feature = "cbor"))] - async fn record(&self, record: ConsumerRecord) -> Option>; + async fn record(&self, entity_id: EntityId, record: ConsumerRecord) -> Option>; #[cfg(feature = "cbor")] async fn producer_record( @@ -112,7 +111,7 @@ where { commit_log: CL, consumer_group_name: String, - marshaller: M, + marshaler: M, topic: Topic, phantom: PhantomData, } @@ -123,11 +122,11 @@ where M: CommitLogRecordMarshaler, for<'async_trait> E: DeserializeOwned + Serialize + Send + Sync + 'async_trait, { - pub fn new(commit_log: CL, marshaller: M, consumer_group_name: &str, topic: TopicRef) -> Self { + pub fn new(commit_log: CL, marshaler: M, consumer_group_name: &str, topic: TopicRef) -> Self { Self { commit_log, consumer_group_name: consumer_group_name.into(), - marshaller, + marshaler, topic: topic.into(), phantom: PhantomData, } @@ -144,38 +143,25 @@ where async fn produce_initial( &mut self, ) -> io::Result> + Send + 'async_trait>>> { - let last_offset = self - .commit_log - .offsets(self.topic.clone(), 0) - .await - .map(|lo| lo.end_offset); - - if let Some(last_offset) = last_offset { - let subscriptions = vec![Subscription { - topic: self.topic.clone(), - }]; - - let mut records = self.commit_log.scoped_subscribe( - &self.consumer_group_name, - vec![], - subscriptions, - None, - ); + let consumer_records = produce_to_last_offset( + &self.commit_log, + &self.consumer_group_name, + self.topic.clone(), + ) + .await; - let marshaller = &self.marshaller; + let marshaler = &self.marshaler; + if let Ok(mut consumer_records) = consumer_records { Ok(Box::pin(stream!({ - while let Some(record) = records.next().await { - if record.offset <= last_offset { - let is_last_offset = record.offset == last_offset; - if let Some(record) = marshaller.record(record).await { + while let Some(consumer_record) = consumer_records.next().await { + if let Some(record_entity_id) = M::to_entity_id(&consumer_record) { + if let Some(record) = + marshaler.record(record_entity_id, consumer_record).await + { yield record; - if !is_last_offset { - continue; - } } } - break; } }))) } else { @@ -187,12 +173,26 @@ where &mut self, entity_id: &EntityId, ) -> io::Result> + Send + 'async_trait>>> { - let records = self.produce_initial().await; - if let Ok(mut records) = records { + let consumer_records = produce_to_last_offset( + &self.commit_log, + &self.consumer_group_name, + self.topic.clone(), + ) + .await; + + let marshaler = &self.marshaler; + + if let Ok(mut consumer_records) = consumer_records { Ok(Box::pin(stream!({ - while let Some(record) = records.next().await { - if &record.entity_id == entity_id { - yield record; + while let Some(consumer_record) = consumer_records.next().await { + if let Some(record_entity_id) = M::to_entity_id(&consumer_record) { + if &record_entity_id == entity_id { + if let Some(record) = + marshaler.record(record_entity_id, consumer_record).await + { + yield record; + } + } } } }))) @@ -203,7 +203,7 @@ where async fn process(&mut self, record: Record) -> io::Result> { let (producer_record, record) = self - .marshaller + .marshaler .producer_record(self.topic.clone(), record) .await .ok_or_else(|| { @@ -225,6 +225,39 @@ where } } +async fn produce_to_last_offset<'async_trait>( + commit_log: &'async_trait impl CommitLog, + consumer_group_name: &str, + topic: Topic, +) -> io::Result + Send + 'async_trait>>> { + let last_offset = commit_log + .offsets(topic.clone(), 0) + .await + .map(|lo| lo.end_offset); + + if let Some(last_offset) = last_offset { + let subscriptions = vec![Subscription { topic }]; + + let mut records = + commit_log.scoped_subscribe(consumer_group_name, vec![], subscriptions, None); + + Ok(Box::pin(stream!({ + while let Some(record) = records.next().await { + if record.offset <= last_offset { + let is_last_offset = record.offset == last_offset; + yield record; + if !is_last_offset { + continue; + } + } + break; + } + }))) + } else { + Ok(Box::pin(tokio_stream::empty())) + } +} + #[cfg(test)] mod tests { use std::{env, fs, num::NonZeroUsize, time::Duration}; @@ -340,8 +373,12 @@ mod tests { panic!("should not be called") } - fn to_entity_id(_record: &ConsumerRecord) -> Option { - panic!("should not be called") + fn to_entity_id(record: &ConsumerRecord) -> Option { + let Header { value, .. } = record + .headers + .iter() + .find(|header| header.key == "entity-id")?; + std::str::from_utf8(value).ok().map(EntityId::from) } fn secret_store(&self) -> &Self::SecretStore { @@ -352,12 +389,11 @@ mod tests { panic!("should not be called") } - async fn record(&self, record: ConsumerRecord) -> Option> { - let Header { value, .. } = record - .headers - .into_iter() - .find(|header| header.key == "entity-id")?; - let entity_id = EntityId::from(std::str::from_utf8(&value).ok()?); + async fn record( + &self, + entity_id: EntityId, + record: ConsumerRecord, + ) -> Option> { let value = String::from_utf8(record.value).ok()?; let event = MyEvent { value }; Some(Record { @@ -403,10 +439,10 @@ mod tests { let commit_log = FileLog::new(logged_dir); - let marshaller = MyEventMarshaler; + let marshaler = MyEventMarshaler; let mut adapter = CommitLogTopicAdapter::new( commit_log.clone(), - marshaller, + marshaler, "some-consumer", "some-topic", ); @@ -492,10 +528,10 @@ mod tests { async fn can_establish_an_entity_manager() { let commit_log = FileLog::new("/dev/null"); - let marshaller = MyEventMarshaler; + let marshaler = MyEventMarshaler; let file_log_topic_adapter = - CommitLogTopicAdapter::new(commit_log, marshaller, "some-consumer", "some-topic"); + CommitLogTopicAdapter::new(commit_log, marshaler, "some-consumer", "some-topic"); let my_behavior = MyBehavior;