Skip to content

Commit

Permalink
refactor(fhe_strings): CharIter is a Vec, cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Oct 24, 2024
1 parent be4be61 commit b917005
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 140 deletions.
19 changes: 1 addition & 18 deletions tfhe/examples/fhe_strings/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,21 +300,4 @@ pub trait FheStringIterator {
fn next(&mut self, sk: &ServerKey) -> (FheString, BooleanBlock);
}

/// Iterator over borrowed `FheAsciiChar`s, in one of two concrete shapes so that both can be
/// handled through a single type.
#[derive(Clone)]
enum CharIter<'a> {
/// Plain iteration over a slice of characters.
Iter(std::slice::Iter<'a, FheAsciiChar>),
/// Iteration over a slice of characters followed by exactly one additional borrowed character.
Extended(
std::iter::Chain<std::slice::Iter<'a, FheAsciiChar>, std::iter::Once<&'a FheAsciiChar>>,
),
}

// Makes `CharIter` usable as an ordinary iterator by forwarding to the iterator it wraps.
impl<'a> Iterator for CharIter<'a> {
// Each item is a reference to a character borrowed for the lifetime `'a`.
type Item = &'a FheAsciiChar;

// Returns whatever the wrapped iterator yields next (`None` once it is exhausted).
fn next(&mut self) -> Option<Self::Item> {
match self {
// Both arms simply delegate to the wrapped standard-library iterator.
CharIter::Iter(iter) => iter.next(),
CharIter::Extended(iter) => iter.next(),
}
}
}
/// Despite the name, this is a `Vec` of borrowed characters rather than an iterator type.
type CharIter<'a> = Vec<&'a FheAsciiChar>;
33 changes: 15 additions & 18 deletions tfhe/examples/fhe_strings/server_key/pattern/contains.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::ciphertext::{FheAsciiChar, FheString, GenericPattern};
use crate::server_key::pattern::{CharIter, IsMatch};
use crate::server_key::ServerKey;
use itertools::Itertools;
use rayon::prelude::*;
use rayon::range::Iter;
use tfhe::integer::{BooleanBlock, IntegerRadixCiphertext, RadixCiphertext};
Expand All @@ -17,18 +18,16 @@ impl ServerKey {

let matched: Vec<_> = par_iter
.map(|start| {
let str_chars = str.clone().skip(start);
let pat_chars = pat.clone();

if ignore_pat_pad {
let str_pat = str_chars.into_iter().zip(pat_chars).par_bridge();
let str_chars = str
.par_iter()
.copied()
.skip(start)
.zip(pat.par_iter().copied());

self.asciis_eq_ignore_pat_pad(str_pat)
self.asciis_eq_ignore_pat_pad(str_chars)
} else {
let a: Vec<&FheAsciiChar> = str_chars.collect();
let b: Vec<&FheAsciiChar> = pat_chars.collect();

self.asciis_eq(a.into_iter(), b.into_iter())
self.asciis_eq(str.iter().copied().skip(start), pat.iter().copied())
}
})
.collect();
Expand All @@ -55,12 +54,7 @@ impl ServerKey {
let (str, pat) = str_pat;

let matched: Vec<_> = par_iter
.map(|start| {
let str_chars = str.clone().skip(start);
let a: Vec<&FheAsciiChar> = str_chars.collect();

self.clear_asciis_eq(a.into_iter(), pat)
})
.map(|start| self.clear_asciis_eq(str.iter().skip(start).copied(), pat))
.collect();

let block_vec: Vec<_> = matched
Expand Down Expand Up @@ -196,12 +190,15 @@ impl ServerKey {
let str_chars = if !str.is_padded() && (str_len < pat_len - 1) {
// If str = "xy" and pat = "xyz\0", then str[..] == pat[..2], but instead we have
// to check if "xy\0" == pat[..3] (i.e. check that the actual pattern isn't longer)
CharIter::Extended(str.chars().iter().chain(std::iter::once(&null)))
str.chars()
.iter()
.chain(std::iter::once(&null))
.collect_vec()
} else {
CharIter::Iter(str.chars().iter())
str.chars().iter().collect_vec()
};

let str_pat = str_chars.into_iter().zip(pat_chars).par_bridge();
let str_pat = str_chars.par_iter().copied().zip(pat_chars.par_iter());

self.asciis_eq_ignore_pat_pad(str_pat)
}
Expand Down
19 changes: 8 additions & 11 deletions tfhe/examples/fhe_strings/server_key/pattern/find.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,16 @@ impl ServerKey {

let matched: Vec<_> = par_iter
.map(|start| {
let str_chars = str.clone().skip(start);
let pat_chars = pat.clone();

let is_matched = if ignore_pat_pad {
let str_pat = str_chars.into_iter().zip(pat_chars).par_bridge();
let str_pat = str
.par_iter()
.copied()
.skip(start)
.zip(pat.par_iter().copied());

self.asciis_eq_ignore_pat_pad(str_pat)
} else {
let a: Vec<&FheAsciiChar> = str_chars.collect();
let b: Vec<&FheAsciiChar> = pat_chars.collect();

self.asciis_eq(a.into_iter(), b.into_iter())
self.asciis_eq(str.iter().skip(start).copied(), pat.iter().copied())
};

(start, is_matched)
Expand Down Expand Up @@ -68,10 +66,9 @@ impl ServerKey {

let matched: Vec<_> = par_iter
.map(|start| {
let str_chars = str.clone().skip(start);
let str_chars = &str[start..];

let a: Vec<&FheAsciiChar> = str_chars.collect();
let is_matched = self.clear_asciis_eq(a.into_iter(), pat);
let is_matched = self.clear_asciis_eq(str_chars.iter().copied(), pat);

(start, is_matched)
})
Expand Down
126 changes: 59 additions & 67 deletions tfhe/examples/fhe_strings/server_key/pattern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod strip;

use crate::ciphertext::{FheAsciiChar, FheString};
use crate::server_key::{CharIter, FheStringIsEmpty, ServerKey};
use itertools::Itertools;
use std::ops::Range;
use tfhe::integer::BooleanBlock;

Expand Down Expand Up @@ -65,73 +66,68 @@ impl ServerKey {
let pat_len = pat.len();
let str_len = str.len();

let range;

let str_chars;
let pat_chars;

match (str.is_padded(), pat.is_padded()) {
// If neither has padding we just check if pat matches the `pat_len` last chars of str
(false, false) => {
let str_chars = str.chars().iter();
let pat_chars = pat.chars().iter();
str_chars = str.chars().iter().collect_vec();
pat_chars = pat.chars().iter().collect_vec();

let start = str_len - pat_len;

let range = start..start + 1;

(CharIter::Iter(str_chars), CharIter::Iter(pat_chars), range)
range = start..start + 1;
}

// If only str is padded we have to check all the possible padding cases. If str is 3
// chars long, then it could be "xx\0", "x\0\0" or "\0\0\0", where x != '\0'
(true, false) => {
let str_chars = str.chars()[..str_len - 1].iter();
let pat_chars = pat.chars().iter().chain(std::iter::once(null.unwrap()));
str_chars = str.chars()[..str_len - 1].iter().collect_vec();
pat_chars = pat
.chars()
.iter()
.chain(std::iter::once(null.unwrap()))
.collect_vec();

let diff = (str_len - 1) - pat_len;

let range = 0..diff + 1;

(
CharIter::Iter(str_chars),
CharIter::Extended(pat_chars),
range,
)
range = 0..diff + 1;
}

// If only pat is padded we have to check all the possible padding cases as well
// If str = "abc" and pat = "abcd\0", we check if "abc\0" == pat[..4]
(false, true) => {
let (str_chars, pat_chars, range) = if pat_len - 1 > str_len {
str_chars = str
.chars()
.iter()
.chain(std::iter::once(null.unwrap()))
.collect_vec();
pat_chars = pat.chars().iter().collect_vec();

if pat_len - 1 > str_len {
// Pat without last char is longer than str so we check all the str chars
(
str.chars().iter().chain(std::iter::once(null.unwrap())),
pat.chars().iter(),
0..str_len + 1,
)
range = 0..str_len + 1;
} else {
// Pat without last char is equal or shorter than str so we check the
// `pat_len` - 1 last chars of str
let start = str_len - (pat_len - 1);
(
str.chars().iter().chain(std::iter::once(null.unwrap())),
pat.chars()[..pat_len - 1].iter(),
start..start + pat_len,
)
};

(
CharIter::Extended(str_chars),
CharIter::Iter(pat_chars),
range,
)
range = start..start + pat_len;
};
}

(true, true) => {
let str_chars = str.chars().iter();
let pat_chars = pat.chars().iter();

let range = 0..str_len;
str_chars = str.chars().iter().collect_vec();
pat_chars = pat.chars().iter().collect_vec();

(CharIter::Iter(str_chars), CharIter::Iter(pat_chars), range)
range = 0..str_len;
}
}

(str_chars, pat_chars, range)
}

fn clear_ends_with_cases<'a>(
Expand All @@ -143,20 +139,20 @@ impl ServerKey {
let str_len = str.len();

if str.is_padded() {
let str_chars = str.chars()[..str_len - 1].iter();
let mut pat_chars = pat.to_owned();

pat_chars.push('\0');
let str_chars = str.chars()[..str_len - 1].iter().collect();
let pat_chars = format!("{pat}\0");

let diff = (str_len - 1) - pat_len;
let range = 0..diff + 1;

(CharIter::Iter(str_chars), pat_chars, range)
(str_chars, pat_chars, range)
} else {
let str_chars = str.chars().iter().collect();

let start = str_len - pat_len;
let range = start..start + 1;

(CharIter::Iter(str.chars().iter()), pat.to_owned(), range)
(str_chars, pat.to_owned(), range)
}
}

Expand All @@ -169,43 +165,39 @@ impl ServerKey {
let pat_len = pat.len();
let str_len = str.len();

let str_chars;
let pat_chars;

let range;

match (str.is_padded(), pat.is_padded()) {
(_, false) => {
let diff = (str_len - pat_len) - if str.is_padded() { 1 } else { 0 };
str_chars = str.chars().iter().collect_vec();
pat_chars = pat.chars().iter().collect_vec();

let range = 0..diff + 1;
let diff = (str_len - pat_len) - if str.is_padded() { 1 } else { 0 };

(
CharIter::Iter(str.chars().iter()),
CharIter::Iter(pat.chars().iter()),
range,
)
range = 0..diff + 1;
}

(true, true) => {
let pat_chars = pat.chars()[..pat_len - 1].iter();

let range = 0..str_len - 1;
str_chars = str.chars().iter().collect_vec();
pat_chars = pat.chars()[..pat_len - 1].iter().collect_vec();

(
CharIter::Iter(str.chars().iter()),
CharIter::Iter(pat_chars),
range,
)
range = 0..str_len - 1;
}

(false, true) => {
let pat_chars = pat.chars()[..pat_len - 1].iter();
let str_chars = str.chars().iter().chain(std::iter::once(null.unwrap()));
str_chars = str
.chars()
.iter()
.chain(std::iter::once(null.unwrap()))
.collect_vec();

let range = 0..str_len;
pat_chars = pat.chars()[..pat_len - 1].iter().collect_vec();

(
CharIter::Extended(str_chars),
CharIter::Iter(pat_chars),
range,
)
range = 0..str_len;
}
}

(str_chars, pat_chars, range)
}
}
Loading

0 comments on commit b917005

Please sign in to comment.