From 8832fe805f3d298f42c606ba2a372a1beeeed921 Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" <69792125+mayeul-zama@users.noreply.github.com> Date: Wed, 30 Oct 2024 16:21:10 +0100 Subject: [PATCH] refactor(string): use custom iterator to avoid allocation --- tfhe/src/strings/char_iter.rs | 175 ++++++++++++++++++ tfhe/src/strings/mod.rs | 1 + tfhe/src/strings/server_key/mod.rs | 2 - .../strings/server_key/pattern/contains.rs | 13 +- tfhe/src/strings/server_key/pattern/find.rs | 15 +- tfhe/src/strings/server_key/pattern/mod.rs | 47 ++--- tfhe/src/strings/server_key/pattern/strip.rs | 7 +- 7 files changed, 207 insertions(+), 53 deletions(-) create mode 100644 tfhe/src/strings/char_iter.rs diff --git a/tfhe/src/strings/char_iter.rs b/tfhe/src/strings/char_iter.rs new file mode 100644 index 0000000000..7ea508df5c --- /dev/null +++ b/tfhe/src/strings/char_iter.rs @@ -0,0 +1,175 @@ +use super::ciphertext::FheAsciiChar; +use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; + +pub(super) type CharIter<'a> = OptionalEndSliceIter<'a, FheAsciiChar>; + +pub(super) struct OptionalEndSliceIter<'a, T> { + slice: &'a [T], + last: Option<&'a T>, +} + +impl<'a, T> Clone for OptionalEndSliceIter<'a, T> { + fn clone(&self) -> Self { + *self + } +} + +impl<'a, T> Copy for OptionalEndSliceIter<'a, T> {} + +impl<'a, T> OptionalEndSliceIter<'a, T> { + pub(super) fn len(&self) -> usize { + self.slice.len() + if self.last.is_some() { 1 } else { 0 } + } + + pub(super) fn new(slice: &'a [T], last: Option<&'a T>) -> Self { + Self { slice, last } + } +} + +pub mod iter { + use super::*; + + impl<'a, T> IntoIterator for OptionalEndSliceIter<'a, T> { + type Item = &'a T; + + type IntoIter = OptionalEndSliceIterator<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + OptionalEndSliceIterator { + slice_iter: self.slice.iter(), + last: self.last, + } + } + } + pub struct OptionalEndSliceIterator<'a, T> { + slice_iter: std::slice::Iter<'a, T>, + last: Option<&'a T>, + } + + impl<'a, T> Iterator for OptionalEndSliceIterator<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if let Some(item) = self.slice_iter.next() { + Some(item) + } else { + self.last.take() + } + } + + fn size_hint(&self) -> (usize, Option) { + (self.len(), Some(self.len())) + } + } + + impl<'a, T> DoubleEndedIterator for OptionalEndSliceIterator<'a, T> { + fn next_back(&mut self) -> Option { + if let Some(last) = self.last.take() { + Some(last) + } else { + self.slice_iter.next_back() + } + } + } + + impl<'a, T> ExactSizeIterator for OptionalEndSliceIterator<'a, T> { + fn len(&self) -> usize { + self.slice_iter.len() + if self.last.is_some() { 1 } else { 0 } + } + } + + #[test] + fn test_iter() { + { + let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], Some(&4)); + + let mut b = a.into_iter(); + + assert_eq!(b.next(), Some(&0)); + assert_eq!(b.next(), Some(&1)); + assert_eq!(b.next(), Some(&2)); + assert_eq!(b.next(), Some(&3)); + assert_eq!(b.next(), Some(&4)); + assert_eq!(b.next(), None); + } + { + let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], None); + + let mut b = a.into_iter(); + + assert_eq!(b.next(), Some(&0)); + assert_eq!(b.next(), Some(&1)); + assert_eq!(b.next(), Some(&2)); + assert_eq!(b.next(), Some(&3)); + assert_eq!(b.next(), None); + } + } + + #[test] + fn test_iter_back() { + { + let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], Some(&4)); + + let mut b = a.into_iter(); + + assert_eq!(b.next_back(), Some(&4)); + assert_eq!(b.next_back(), Some(&3)); + assert_eq!(b.next_back(), Some(&2)); + assert_eq!(b.next_back(), Some(&1)); + assert_eq!(b.next_back(), Some(&0)); + assert_eq!(b.next_back(), None); + } + + { + let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], None); + + let mut b = a.into_iter(); + + assert_eq!(b.next_back(), Some(&3)); + assert_eq!(b.next_back(), Some(&2)); + assert_eq!(b.next_back(), Some(&1)); + assert_eq!(b.next_back(), Some(&0)); + assert_eq!(b.next_back(), None); + } + } + + #[test] + fn test_iter_mix() { + { + let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], Some(&4)); + + let mut b = a.into_iter(); + + assert_eq!(b.next_back(), Some(&4)); + assert_eq!(b.next(), Some(&0)); + assert_eq!(b.next_back(), Some(&3)); + assert_eq!(b.next(), Some(&1)); + assert_eq!(b.next_back(), Some(&2)); + assert_eq!(b.next(), None); + } + { + let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], None); + + let mut b = a.into_iter(); + + assert_eq!(b.next_back(), Some(&3)); + assert_eq!(b.next(), Some(&0)); + assert_eq!(b.next_back(), Some(&2)); + assert_eq!(b.next(), Some(&1)); + assert_eq!(b.next_back(), None); + } + } +} + +impl<'a> IntoParallelRefIterator<'a> for CharIter<'a> { + type Item = &'a FheAsciiChar; + + type Iter = rayon::iter::Chain< + rayon::slice::Iter<'a, FheAsciiChar>, + rayon::option::IntoIter<&'a FheAsciiChar>, + >; + + fn par_iter(&'a self) -> Self::Iter { + self.slice.par_iter().chain(self.last.into_par_iter()) + } +} diff --git a/tfhe/src/strings/mod.rs b/tfhe/src/strings/mod.rs index 69070d7999..40128712d7 100644 --- a/tfhe/src/strings/mod.rs +++ b/tfhe/src/strings/mod.rs @@ -4,6 +4,7 @@ pub mod server_key; #[cfg(test)] mod assert_functions; +mod char_iter; // Used as the const argument for StaticUnsignedBigInt, specifying the max chars length of a // ClearString diff --git a/tfhe/src/strings/server_key/mod.rs b/tfhe/src/strings/server_key/mod.rs index 14a18ae896..f6807aa150 100644 --- a/tfhe/src/strings/server_key/mod.rs +++ b/tfhe/src/strings/server_key/mod.rs @@ -277,5 +277,3 @@ impl ServerKey { pub trait FheStringIterator { fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock); } - -type CharIter<'a> = Vec<&'a FheAsciiChar>; diff --git a/tfhe/src/strings/server_key/pattern/contains.rs b/tfhe/src/strings/server_key/pattern/contains.rs index a04a288ae7..f2d68a9114 100644 --- a/tfhe/src/strings/server_key/pattern/contains.rs +++ b/tfhe/src/strings/server_key/pattern/contains.rs @@ -1,7 +1,8 @@ use super::{clear_ends_with_cases, contains_cases, ends_with_cases}; use crate::integer::{BooleanBlock, IntegerRadixCiphertext, RadixCiphertext}; +use crate::strings::char_iter::CharIter; use crate::strings::ciphertext::{FheAsciiChar, FheString, GenericPattern}; -use crate::strings::server_key::pattern::{CharIter, IsMatch}; +use crate::strings::server_key::pattern::IsMatch; use crate::strings::server_key::ServerKey; use itertools::Itertools; use rayon::prelude::*; @@ -20,15 +21,11 @@ impl ServerKey { let matched: Vec<_> = par_iter .map(|start| { if ignore_pat_pad { - let str_chars = str - .par_iter() - .copied() - .skip(start) - .zip(pat.par_iter().copied()); + let str_chars = str.par_iter().skip(start).zip(pat.par_iter()); self.asciis_eq_ignore_pat_pad(str_chars) } else { - self.asciis_eq(str.iter().copied().skip(start), pat.iter().copied()) + self.asciis_eq(str.into_iter().skip(start), pat.into_iter()) } }) .collect(); @@ -55,7 +52,7 @@ impl ServerKey { let (str, pat) = str_pat; let matched: Vec<_> = par_iter - .map(|start| self.clear_asciis_eq(str.iter().skip(start).copied(), pat)) + .map(|start| self.clear_asciis_eq(str.into_iter().skip(start), pat)) .collect(); let block_vec: Vec<_> = matched diff --git a/tfhe/src/strings/server_key/pattern/find.rs b/tfhe/src/strings/server_key/pattern/find.rs index a807f5c301..db86cbe9e8 100644 --- a/tfhe/src/strings/server_key/pattern/find.rs +++ b/tfhe/src/strings/server_key/pattern/find.rs @@ -1,9 +1,10 @@ use super::contains_cases; use crate::integer::prelude::*; use crate::integer::{BooleanBlock, RadixCiphertext}; +use crate::strings::char_iter::CharIter; use crate::strings::ciphertext::{FheAsciiChar, FheString, GenericPattern}; use crate::strings::server_key::pattern::IsMatch; -use crate::strings::server_key::{CharIter, FheStringIsEmpty, FheStringLen, ServerKey}; +use crate::strings::server_key::{FheStringIsEmpty, FheStringLen, ServerKey}; use rayon::prelude::*; use rayon::vec::IntoIter; @@ -24,15 +25,11 @@ impl ServerKey { let matched: Vec<_> = par_iter .map(|start| { let is_matched = if ignore_pat_pad { - let str_pat = str - .par_iter() - .copied() - .skip(start) - .zip(pat.par_iter().copied()); + let str_pat = str.par_iter().skip(start).zip(pat.par_iter()); self.asciis_eq_ignore_pat_pad(str_pat) } else { - self.asciis_eq(str.iter().skip(start).copied(), pat.iter().copied()) + self.asciis_eq(str.into_iter().skip(start), pat.into_iter()) }; (start, is_matched) @@ -66,9 +63,7 @@ impl ServerKey { let matched: Vec<_> = par_iter .map(|start| { - let str_chars = &str[start..]; - - let is_matched = self.clear_asciis_eq(str_chars.iter().copied(), pat); + let is_matched = self.clear_asciis_eq(str.into_iter().skip(start), pat); (start, is_matched) }) diff --git a/tfhe/src/strings/server_key/pattern/mod.rs b/tfhe/src/strings/server_key/pattern/mod.rs index a25f7c1632..55d9243f19 100644 --- a/tfhe/src/strings/server_key/pattern/mod.rs +++ b/tfhe/src/strings/server_key/pattern/mod.rs @@ -5,9 +5,9 @@ mod split; mod strip; use crate::integer::BooleanBlock; +use crate::strings::char_iter::CharIter; use crate::strings::ciphertext::{FheAsciiChar, FheString}; -use crate::strings::server_key::{CharIter, FheStringIsEmpty, ServerKey}; -use itertools::Itertools; +use crate::strings::server_key::{FheStringIsEmpty, ServerKey}; use std::ops::Range; // Useful for handling cases in which we know if there is or there isn't a match just by looking at @@ -73,8 +73,8 @@ fn ends_with_cases<'a>( match (str.is_padded(), pat.is_padded()) { // If neither has padding we just check if pat matches the `pat_len` last chars or str (false, false) => { - str_chars = str.chars().iter().collect_vec(); - pat_chars = pat.chars().iter().collect_vec(); + str_chars = CharIter::new(str.chars(), None); + pat_chars = CharIter::new(pat.chars(), None); let start = str_len - pat_len; @@ -84,27 +84,18 @@ fn ends_with_cases<'a>( // If only str is padded we have to check all the possible padding cases. If str is 3 // chars long, then it could be "xx\0", "x\0\0" or "\0\0\0", where x != '\0' (true, false) => { - str_chars = str.chars()[..str_len - 1].iter().collect_vec(); - pat_chars = pat - .chars() - .iter() - .chain(std::iter::once(null.unwrap())) - .collect_vec(); + str_chars = CharIter::new(&str.chars()[..str_len - 1], None); + pat_chars = CharIter::new(pat.chars(), Some(null.unwrap())); let diff = (str_len - 1) - pat_len; range = 0..diff + 1; } - // If only pat is padded we have to check all the possible padding cases as well // If str = "abc" and pat = "abcd\0", we check if "abc\0" == pat[..4] (false, true) => { - str_chars = str - .chars() - .iter() - .chain(std::iter::once(null.unwrap())) - .collect_vec(); - pat_chars = pat.chars().iter().collect_vec(); + str_chars = CharIter::new(str.chars(), Some(null.unwrap())); + pat_chars = CharIter::new(pat.chars(), None); if pat_len - 1 > str_len { // Pat without last char is longer than str so we check all the str chars @@ -119,8 +110,8 @@ fn ends_with_cases<'a>( } (true, true) => { - str_chars = str.chars().iter().collect_vec(); - pat_chars = pat.chars().iter().collect_vec(); + str_chars = CharIter::new(str.chars(), None); + pat_chars = CharIter::new(pat.chars(), None); range = 0..str_len; } @@ -137,7 +128,7 @@ fn clear_ends_with_cases<'a>( let str_len = str.len(); if str.is_padded() { - let str_chars = str.chars()[..str_len - 1].iter().collect(); + let str_chars = CharIter::new(&str.chars()[..str_len - 1], None); let pat_chars = format!("{pat}\0"); let diff = (str_len - 1) - pat_len; @@ -145,7 +136,7 @@ fn clear_ends_with_cases<'a>( (str_chars, pat_chars, range) } else { - let str_chars = str.chars().iter().collect(); + let str_chars = CharIter::new(str.chars(), None); let start = str_len - pat_len; let range = start..start + 1; @@ -168,24 +159,20 @@ fn contains_cases<'a>( let range; if pat.is_padded() { - pat_chars = pat.chars()[..pat_len - 1].iter().collect_vec(); + pat_chars = CharIter::new(&pat.chars()[..pat_len - 1], None); if str.is_padded() { - str_chars = str.chars().iter().collect_vec(); + str_chars = CharIter::new(str.chars(), None); range = 0..str_len - 1; } else { - str_chars = str - .chars() - .iter() - .chain(std::iter::once(null.unwrap())) - .collect_vec(); + str_chars = CharIter::new(str.chars(), Some(null.unwrap())); range = 0..str_len; } } else { - str_chars = str.chars().iter().collect_vec(); - pat_chars = pat.chars().iter().collect_vec(); + str_chars = CharIter::new(str.chars(), None); + pat_chars = CharIter::new(pat.chars(), None); let diff = (str_len - pat_len) - if str.is_padded() { 1 } else { 0 }; diff --git a/tfhe/src/strings/server_key/pattern/strip.rs b/tfhe/src/strings/server_key/pattern/strip.rs index e99ad2958e..68c8a2df22 100644 --- a/tfhe/src/strings/server_key/pattern/strip.rs +++ b/tfhe/src/strings/server_key/pattern/strip.rs @@ -1,9 +1,10 @@ use super::{clear_ends_with_cases, ends_with_cases}; use crate::integer::prelude::*; use crate::integer::BooleanBlock; +use crate::strings::char_iter::CharIter; use crate::strings::ciphertext::{FheAsciiChar, FheString, GenericPattern}; use crate::strings::server_key::pattern::IsMatch; -use crate::strings::server_key::{CharIter, FheStringLen, ServerKey}; +use crate::strings::server_key::{FheStringLen, ServerKey}; use rayon::prelude::*; use std::ops::Range; @@ -21,7 +22,7 @@ impl ServerKey { let str_len = str.len(); for start in iter { - let is_matched = self.asciis_eq(str.iter().copied().skip(start), pat.iter().copied()); + let is_matched = self.asciis_eq(str.into_iter().skip(start), pat.into_iter()); let mut mask = is_matched.clone().into_radix(self.num_ascii_blocks(), self); @@ -62,7 +63,7 @@ impl ServerKey { let pat_len = pat.len(); let str_len = str.len(); for start in iter { - let is_matched = self.clear_asciis_eq(str.iter().copied().skip(start), pat); + let is_matched = self.clear_asciis_eq(str.into_iter().skip(start), pat); let mut mask = is_matched.clone().into_radix(self.num_ascii_blocks(), self);