Skip to content

Commit

Permalink
refactor(string): use custom iterator to avoid allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Nov 6, 2024
1 parent a9bfd6e commit 8832fe8
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 53 deletions.
175 changes: 175 additions & 0 deletions tfhe/src/strings/char_iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
use super::ciphertext::FheAsciiChar;
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};

pub(super) type CharIter<'a> = OptionalEndSliceIter<'a, FheAsciiChar>;

pub(super) struct OptionalEndSliceIter<'a, T> {
slice: &'a [T],
last: Option<&'a T>,
}

impl<'a, T> Clone for OptionalEndSliceIter<'a, T> {
fn clone(&self) -> Self {
*self
}
}

impl<'a, T> Copy for OptionalEndSliceIter<'a, T> {}

impl<'a, T> OptionalEndSliceIter<'a, T> {
pub(super) fn len(&self) -> usize {
self.slice.len() + if self.last.is_some() { 1 } else { 0 }
}

pub(super) fn new(slice: &'a [T], last: Option<&'a T>) -> Self {
Self { slice, last }
}
}

pub mod iter {
use super::*;

impl<'a, T> IntoIterator for OptionalEndSliceIter<'a, T> {
type Item = &'a T;

type IntoIter = OptionalEndSliceIterator<'a, T>;

fn into_iter(self) -> Self::IntoIter {
OptionalEndSliceIterator {
slice_iter: self.slice.iter(),
last: self.last,
}
}
}
pub struct OptionalEndSliceIterator<'a, T> {
slice_iter: std::slice::Iter<'a, T>,
last: Option<&'a T>,
}

impl<'a, T> Iterator for OptionalEndSliceIterator<'a, T> {
type Item = &'a T;

fn next(&mut self) -> Option<Self::Item> {
if let Some(item) = self.slice_iter.next() {
Some(item)
} else {
self.last.take()
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
(self.len(), Some(self.len()))
}
}

impl<'a, T> DoubleEndedIterator for OptionalEndSliceIterator<'a, T> {
fn next_back(&mut self) -> Option<Self::Item> {
if let Some(last) = self.last.take() {
Some(last)
} else {
self.slice_iter.next_back()
}
}
}

impl<'a, T> ExactSizeIterator for OptionalEndSliceIterator<'a, T> {
fn len(&self) -> usize {
self.slice_iter.len() + if self.last.is_some() { 1 } else { 0 }
}
}

#[test]
fn test_iter() {
{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], Some(&4));

let mut b = a.into_iter();

assert_eq!(b.next(), Some(&0));
assert_eq!(b.next(), Some(&1));
assert_eq!(b.next(), Some(&2));
assert_eq!(b.next(), Some(&3));
assert_eq!(b.next(), Some(&4));
assert_eq!(b.next(), None);
}
{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], None);

let mut b = a.into_iter();

assert_eq!(b.next(), Some(&0));
assert_eq!(b.next(), Some(&1));
assert_eq!(b.next(), Some(&2));
assert_eq!(b.next(), Some(&3));
assert_eq!(b.next(), None);
}
}

#[test]
fn test_iter_back() {
{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], Some(&4));

let mut b = a.into_iter();

assert_eq!(b.next_back(), Some(&4));
assert_eq!(b.next_back(), Some(&3));
assert_eq!(b.next_back(), Some(&2));
assert_eq!(b.next_back(), Some(&1));
assert_eq!(b.next_back(), Some(&0));
assert_eq!(b.next_back(), None);
}

{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], None);

let mut b = a.into_iter();

assert_eq!(b.next_back(), Some(&3));
assert_eq!(b.next_back(), Some(&2));
assert_eq!(b.next_back(), Some(&1));
assert_eq!(b.next_back(), Some(&0));
assert_eq!(b.next_back(), None);
}
}

#[test]
fn test_iter_mix() {
{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], Some(&4));

let mut b = a.into_iter();

assert_eq!(b.next_back(), Some(&4));
assert_eq!(b.next(), Some(&0));
assert_eq!(b.next_back(), Some(&3));
assert_eq!(b.next(), Some(&1));
assert_eq!(b.next_back(), Some(&2));
assert_eq!(b.next(), None);
}
{
let a = OptionalEndSliceIter::new(&[0, 1, 2, 3], None);

let mut b = a.into_iter();

assert_eq!(b.next_back(), Some(&3));
assert_eq!(b.next(), Some(&0));
assert_eq!(b.next_back(), Some(&2));
assert_eq!(b.next(), Some(&1));
assert_eq!(b.next_back(), None);
}
}
}

impl<'a> IntoParallelRefIterator<'a> for CharIter<'a> {
type Item = &'a FheAsciiChar;

type Iter = rayon::iter::Chain<
rayon::slice::Iter<'a, FheAsciiChar>,
rayon::option::IntoIter<&'a FheAsciiChar>,
>;

fn par_iter(&'a self) -> Self::Iter {
self.slice.par_iter().chain(self.last.into_par_iter())
}
}
1 change: 1 addition & 0 deletions tfhe/src/strings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod server_key;

#[cfg(test)]
mod assert_functions;
mod char_iter;

// Used as the const argument for StaticUnsignedBigInt, specifying the max chars length of a
// ClearString
Expand Down
2 changes: 0 additions & 2 deletions tfhe/src/strings/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,5 +277,3 @@ impl ServerKey {
pub trait FheStringIterator {
fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock);
}

type CharIter<'a> = Vec<&'a FheAsciiChar>;
13 changes: 5 additions & 8 deletions tfhe/src/strings/server_key/pattern/contains.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::{clear_ends_with_cases, contains_cases, ends_with_cases};
use crate::integer::{BooleanBlock, IntegerRadixCiphertext, RadixCiphertext};
use crate::strings::char_iter::CharIter;
use crate::strings::ciphertext::{FheAsciiChar, FheString, GenericPattern};
use crate::strings::server_key::pattern::{CharIter, IsMatch};
use crate::strings::server_key::pattern::IsMatch;
use crate::strings::server_key::ServerKey;
use itertools::Itertools;
use rayon::prelude::*;
Expand All @@ -20,15 +21,11 @@ impl ServerKey {
let matched: Vec<_> = par_iter
.map(|start| {
if ignore_pat_pad {
let str_chars = str
.par_iter()
.copied()
.skip(start)
.zip(pat.par_iter().copied());
let str_chars = str.par_iter().skip(start).zip(pat.par_iter());

self.asciis_eq_ignore_pat_pad(str_chars)
} else {
self.asciis_eq(str.iter().copied().skip(start), pat.iter().copied())
self.asciis_eq(str.into_iter().skip(start), pat.into_iter())
}
})
.collect();
Expand All @@ -55,7 +52,7 @@ impl ServerKey {
let (str, pat) = str_pat;

let matched: Vec<_> = par_iter
.map(|start| self.clear_asciis_eq(str.iter().skip(start).copied(), pat))
.map(|start| self.clear_asciis_eq(str.into_iter().skip(start), pat))
.collect();

let block_vec: Vec<_> = matched
Expand Down
15 changes: 5 additions & 10 deletions tfhe/src/strings/server_key/pattern/find.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use super::contains_cases;
use crate::integer::prelude::*;
use crate::integer::{BooleanBlock, RadixCiphertext};
use crate::strings::char_iter::CharIter;
use crate::strings::ciphertext::{FheAsciiChar, FheString, GenericPattern};
use crate::strings::server_key::pattern::IsMatch;
use crate::strings::server_key::{CharIter, FheStringIsEmpty, FheStringLen, ServerKey};
use crate::strings::server_key::{FheStringIsEmpty, FheStringLen, ServerKey};
use rayon::prelude::*;
use rayon::vec::IntoIter;

Expand All @@ -24,15 +25,11 @@ impl ServerKey {
let matched: Vec<_> = par_iter
.map(|start| {
let is_matched = if ignore_pat_pad {
let str_pat = str
.par_iter()
.copied()
.skip(start)
.zip(pat.par_iter().copied());
let str_pat = str.par_iter().skip(start).zip(pat.par_iter());

self.asciis_eq_ignore_pat_pad(str_pat)
} else {
self.asciis_eq(str.iter().skip(start).copied(), pat.iter().copied())
self.asciis_eq(str.into_iter().skip(start), pat.into_iter())
};

(start, is_matched)
Expand Down Expand Up @@ -66,9 +63,7 @@ impl ServerKey {

let matched: Vec<_> = par_iter
.map(|start| {
let str_chars = &str[start..];

let is_matched = self.clear_asciis_eq(str_chars.iter().copied(), pat);
let is_matched = self.clear_asciis_eq(str.into_iter().skip(start), pat);

(start, is_matched)
})
Expand Down
47 changes: 17 additions & 30 deletions tfhe/src/strings/server_key/pattern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ mod split;
mod strip;

use crate::integer::BooleanBlock;
use crate::strings::char_iter::CharIter;
use crate::strings::ciphertext::{FheAsciiChar, FheString};
use crate::strings::server_key::{CharIter, FheStringIsEmpty, ServerKey};
use itertools::Itertools;
use crate::strings::server_key::{FheStringIsEmpty, ServerKey};
use std::ops::Range;

// Useful for handling cases in which we know if there is or there isn't a match just by looking at
Expand Down Expand Up @@ -73,8 +73,8 @@ fn ends_with_cases<'a>(
match (str.is_padded(), pat.is_padded()) {
// If neither has padding we just check if pat matches the `pat_len` last chars or str
(false, false) => {
str_chars = str.chars().iter().collect_vec();
pat_chars = pat.chars().iter().collect_vec();
str_chars = CharIter::new(str.chars(), None);
pat_chars = CharIter::new(pat.chars(), None);

let start = str_len - pat_len;

Expand All @@ -84,27 +84,18 @@ fn ends_with_cases<'a>(
// If only str is padded we have to check all the possible padding cases. If str is 3
// chars long, then it could be "xx\0", "x\0\0" or "\0\0\0", where x != '\0'
(true, false) => {
str_chars = str.chars()[..str_len - 1].iter().collect_vec();
pat_chars = pat
.chars()
.iter()
.chain(std::iter::once(null.unwrap()))
.collect_vec();
str_chars = CharIter::new(&str.chars()[..str_len - 1], None);
pat_chars = CharIter::new(pat.chars(), Some(null.unwrap()));

let diff = (str_len - 1) - pat_len;

range = 0..diff + 1;
}

// If only pat is padded we have to check all the possible padding cases as well
// If str = "abc" and pat = "abcd\0", we check if "abc\0" == pat[..4]
(false, true) => {
str_chars = str
.chars()
.iter()
.chain(std::iter::once(null.unwrap()))
.collect_vec();
pat_chars = pat.chars().iter().collect_vec();
str_chars = CharIter::new(str.chars(), Some(null.unwrap()));
pat_chars = CharIter::new(pat.chars(), None);

if pat_len - 1 > str_len {
// Pat without last char is longer than str so we check all the str chars
Expand All @@ -119,8 +110,8 @@ fn ends_with_cases<'a>(
}

(true, true) => {
str_chars = str.chars().iter().collect_vec();
pat_chars = pat.chars().iter().collect_vec();
str_chars = CharIter::new(str.chars(), None);
pat_chars = CharIter::new(pat.chars(), None);

range = 0..str_len;
}
Expand All @@ -137,15 +128,15 @@ fn clear_ends_with_cases<'a>(
let str_len = str.len();

if str.is_padded() {
let str_chars = str.chars()[..str_len - 1].iter().collect();
let str_chars = CharIter::new(&str.chars()[..str_len - 1], None);
let pat_chars = format!("{pat}\0");

let diff = (str_len - 1) - pat_len;
let range = 0..diff + 1;

(str_chars, pat_chars, range)
} else {
let str_chars = str.chars().iter().collect();
let str_chars = CharIter::new(str.chars(), None);

let start = str_len - pat_len;
let range = start..start + 1;
Expand All @@ -168,24 +159,20 @@ fn contains_cases<'a>(
let range;

if pat.is_padded() {
pat_chars = pat.chars()[..pat_len - 1].iter().collect_vec();
pat_chars = CharIter::new(&pat.chars()[..pat_len - 1], None);

if str.is_padded() {
str_chars = str.chars().iter().collect_vec();
str_chars = CharIter::new(str.chars(), None);

range = 0..str_len - 1;
} else {
str_chars = str
.chars()
.iter()
.chain(std::iter::once(null.unwrap()))
.collect_vec();
str_chars = CharIter::new(str.chars(), Some(null.unwrap()));

range = 0..str_len;
}
} else {
str_chars = str.chars().iter().collect_vec();
pat_chars = pat.chars().iter().collect_vec();
str_chars = CharIter::new(str.chars(), None);
pat_chars = CharIter::new(pat.chars(), None);

let diff = (str_len - pat_len) - if str.is_padded() { 1 } else { 0 };

Expand Down
Loading

0 comments on commit 8832fe8

Please sign in to comment.