Skip to content

Commit

Permalink
refactor(string): use custom iterator to avoid allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Oct 31, 2024
1 parent ceac3d8 commit 02b21d4
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 53 deletions.
211 changes: 211 additions & 0 deletions tfhe/src/strings/char_iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
use super::ciphertext::FheAsciiChar;
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};

pub(super) type CharIter<'a> = OptionalEndSliceIter<'a, FheAsciiChar>;

pub(super) struct OptionalEndSliceIter<'a, T> {
slice: &'a [T],
last: Option<&'a T>,
}

impl<'a, T> Clone for OptionalEndSliceIter<'a, T> {
fn clone(&self) -> Self {
*self
}
}

impl<'a, T> Copy for OptionalEndSliceIter<'a, T> {}

impl<'a, T> OptionalEndSliceIter<'a, T> {
pub(super) fn len(&self) -> usize {
self.slice.len() + if self.last.is_some() { 1 } else { 0 }
}

pub(super) fn new(slice: &'a [T], last: Option<&'a T>) -> Self {
Self { slice, last }
}
}

pub mod iter {
use super::*;
use std::cmp::Ordering;

impl<'a, T> IntoIterator for OptionalEndSliceIter<'a, T> {
type Item = &'a T;

type IntoIter = OptionalEndSliceIterator<'a, T>;

fn into_iter(self) -> Self::IntoIter {
OptionalEndSliceIterator {
iter: self,
index: 0,
back_index: self.len(),
}
}
}

pub struct OptionalEndSliceIterator<'a, T> {
iter: OptionalEndSliceIter<'a, T>,
index: usize,
back_index: usize,
}

impl<'a, T> Iterator for OptionalEndSliceIterator<'a, T> {
type Item = &'a T;

fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.back_index {
return None;
}

match self.index.cmp(&self.iter.slice.len()) {
Ordering::Less => {
let result = &self.iter.slice[self.index];

self.index += 1;

Some(result)
}
Ordering::Equal => {
if let Some(last) = self.iter.last {
self.index += 1;

Some(last)
} else {
None
}
}
Ordering::Greater => None,
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
(self.iter.len(), Some(self.iter.len()))
}
}

impl<'a, T> DoubleEndedIterator for OptionalEndSliceIterator<'a, T> {
fn next_back(&mut self) -> Option<Self::Item> {
if self.index >= self.back_index {
return None;
}

if self.back_index == 0 {
return None;
}

self.back_index -= 1;

match self.back_index.cmp(&self.iter.slice.len()) {
Ordering::Less => {
let result = &self.iter.slice[self.back_index];

Some(result)
}
Ordering::Equal => Some(self.iter.last.unwrap()),
Ordering::Greater => unreachable!(),
}
}
}

impl<'a, T> ExactSizeIterator for OptionalEndSliceIterator<'a, T> {
fn len(&self) -> usize {
self.back_index - self.index
}
}

#[test]
fn test_iter() {
{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], Some(&4));

let mut b = a.into_iter();

assert_eq!(b.next(), Some(&0));
assert_eq!(b.next(), Some(&1));
assert_eq!(b.next(), Some(&2));
assert_eq!(b.next(), Some(&3));
assert_eq!(b.next(), Some(&4));
assert_eq!(b.next(), None);
}
{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], None);

let mut b = a.into_iter();

assert_eq!(b.next(), Some(&0));
assert_eq!(b.next(), Some(&1));
assert_eq!(b.next(), Some(&2));
assert_eq!(b.next(), Some(&3));
assert_eq!(b.next(), None);
}
}

#[test]
fn test_iter_back() {
{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], Some(&4));

let mut b = a.into_iter();

assert_eq!(b.next_back(), Some(&4));
assert_eq!(b.next_back(), Some(&3));
assert_eq!(b.next_back(), Some(&2));
assert_eq!(b.next_back(), Some(&1));
assert_eq!(b.next_back(), Some(&0));
assert_eq!(b.next_back(), None);
}

{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], None);

let mut b = a.into_iter();

assert_eq!(b.next_back(), Some(&3));
assert_eq!(b.next_back(), Some(&2));
assert_eq!(b.next_back(), Some(&1));
assert_eq!(b.next_back(), Some(&0));
assert_eq!(b.next_back(), None);
}
}

#[test]
fn test_iter_mix() {
{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], Some(&4));

let mut b = a.into_iter();

assert_eq!(b.next_back(), Some(&4));
assert_eq!(b.next(), Some(&0));
assert_eq!(b.next_back(), Some(&3));
assert_eq!(b.next(), Some(&1));
assert_eq!(b.next_back(), Some(&2));
assert_eq!(b.next(), None);
}
{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], None);

let mut b = a.into_iter();

assert_eq!(b.next_back(), Some(&3));
assert_eq!(b.next(), Some(&0));
assert_eq!(b.next_back(), Some(&2));
assert_eq!(b.next(), Some(&1));
assert_eq!(b.next_back(), None);
}
}
}

impl<'a> IntoParallelRefIterator<'a> for CharIter<'a> {
type Item = &'a FheAsciiChar;

type Iter = rayon::iter::Chain<
rayon::slice::Iter<'a, FheAsciiChar>,
rayon::option::IntoIter<&'a FheAsciiChar>,
>;

fn par_iter(&'a self) -> Self::Iter {
self.slice.par_iter().chain(self.last.into_par_iter())
}
}
1 change: 1 addition & 0 deletions tfhe/src/strings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod server_key;

#[cfg(test)]
mod assert_functions;
mod char_iter;

// Used as the const argument for StaticUnsignedBigInt, specifying the max chars length of a
// ClearString
Expand Down
2 changes: 0 additions & 2 deletions tfhe/src/strings/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,5 +277,3 @@ impl ServerKey {
pub trait FheStringIterator {
fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock);
}

type CharIter<'a> = Vec<&'a FheAsciiChar>;
13 changes: 5 additions & 8 deletions tfhe/src/strings/server_key/pattern/contains.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::{clear_ends_with_cases, contains_cases, ends_with_cases};
use crate::integer::{BooleanBlock, IntegerRadixCiphertext, RadixCiphertext};
use crate::strings::char_iter::CharIter;
use crate::strings::ciphertext::{FheAsciiChar, FheString, GenericPattern};
use crate::strings::server_key::pattern::{CharIter, IsMatch};
use crate::strings::server_key::pattern::IsMatch;
use crate::strings::server_key::ServerKey;
use itertools::Itertools;
use rayon::prelude::*;
Expand All @@ -20,15 +21,11 @@ impl ServerKey {
let matched: Vec<_> = par_iter
.map(|start| {
if ignore_pat_pad {
let str_chars = str
.par_iter()
.copied()
.skip(start)
.zip(pat.par_iter().copied());
let str_chars = str.par_iter().skip(start).zip(pat.par_iter());

self.asciis_eq_ignore_pat_pad(str_chars)
} else {
self.asciis_eq(str.iter().copied().skip(start), pat.iter().copied())
self.asciis_eq(str.into_iter().skip(start), pat.into_iter())
}
})
.collect();
Expand All @@ -55,7 +52,7 @@ impl ServerKey {
let (str, pat) = str_pat;

let matched: Vec<_> = par_iter
.map(|start| self.clear_asciis_eq(str.iter().skip(start).copied(), pat))
.map(|start| self.clear_asciis_eq(str.into_iter().skip(start), pat))
.collect();

let block_vec: Vec<_> = matched
Expand Down
15 changes: 5 additions & 10 deletions tfhe/src/strings/server_key/pattern/find.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use super::contains_cases;
use crate::integer::prelude::*;
use crate::integer::{BooleanBlock, RadixCiphertext};
use crate::strings::char_iter::CharIter;
use crate::strings::ciphertext::{FheAsciiChar, FheString, GenericPattern};
use crate::strings::server_key::pattern::IsMatch;
use crate::strings::server_key::{CharIter, FheStringIsEmpty, FheStringLen, ServerKey};
use crate::strings::server_key::{FheStringIsEmpty, FheStringLen, ServerKey};
use rayon::prelude::*;
use rayon::vec::IntoIter;

Expand All @@ -24,15 +25,11 @@ impl ServerKey {
let matched: Vec<_> = par_iter
.map(|start| {
let is_matched = if ignore_pat_pad {
let str_pat = str
.par_iter()
.copied()
.skip(start)
.zip(pat.par_iter().copied());
let str_pat = str.par_iter().skip(start).zip(pat.par_iter());

self.asciis_eq_ignore_pat_pad(str_pat)
} else {
self.asciis_eq(str.iter().skip(start).copied(), pat.iter().copied())
self.asciis_eq(str.into_iter().skip(start), pat.into_iter())
};

(start, is_matched)
Expand Down Expand Up @@ -66,9 +63,7 @@ impl ServerKey {

let matched: Vec<_> = par_iter
.map(|start| {
let str_chars = &str[start..];

let is_matched = self.clear_asciis_eq(str_chars.iter().copied(), pat);
let is_matched = self.clear_asciis_eq(str.into_iter().skip(start), pat);

(start, is_matched)
})
Expand Down
Loading

0 comments on commit 02b21d4

Please sign in to comment.