From 59eeb6319d4446b0afd88ca0513831997aa8be6e Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" <69792125+mayeul-zama@users.noreply.github.com> Date: Thu, 24 Oct 2024 16:39:46 +0200 Subject: [PATCH] strings: use integer keys --- tfhe/src/strings/assert_functions/mod.rs | 84 +++++++------- tfhe/src/strings/ciphertext.rs | 20 ++-- tfhe/src/strings/client_key.rs | 29 +---- tfhe/src/strings/mod.rs | 11 +- tfhe/src/strings/server_key/comp.rs | 74 ++++++------ tfhe/src/strings/server_key/mod.rs | 89 +++++---------- tfhe/src/strings/server_key/no_patterns.rs | 96 +++++++++------- .../strings/server_key/pattern/contains.rs | 44 ++++--- tfhe/src/strings/server_key/pattern/find.rs | 61 +++++----- .../src/strings/server_key/pattern/replace.rs | 74 ++++++------ .../strings/server_key/pattern/split/mod.rs | 108 +++++++++--------- .../server_key/pattern/split/split_iters.rs | 89 +++++++++------ tfhe/src/strings/server_key/pattern/strip.rs | 52 +++++---- tfhe/src/strings/server_key/trim.rs | 105 ++++++++--------- 14 files changed, 456 insertions(+), 480 deletions(-) diff --git a/tfhe/src/strings/assert_functions/mod.rs b/tfhe/src/strings/assert_functions/mod.rs index 188cdc6806..62b755371f 100644 --- a/tfhe/src/strings/assert_functions/mod.rs +++ b/tfhe/src/strings/assert_functions/mod.rs @@ -93,7 +93,7 @@ impl Keys { let dec = match result { FheStringLen::NoPadding(clear_len) => clear_len, - FheStringLen::Padding(enc_len) => self.ck.key().decrypt_radix::(&enc_len) as usize, + FheStringLen::Padding(enc_len) => self.ck.decrypt_radix::(&enc_len) as usize, }; println!("\n\x1b[1mLen:\x1b[0m"); @@ -113,7 +113,7 @@ impl Keys { let dec = match result { FheStringIsEmpty::NoPadding(clear_len) => clear_len, - FheStringIsEmpty::Padding(enc_len) => self.ck.key().decrypt_bool(&enc_len), + FheStringIsEmpty::Padding(enc_len) => self.ck.decrypt_bool(&enc_len), }; println!("\n\x1b[1mIs_empty:\x1b[0m"); @@ -150,7 +150,7 @@ impl Keys { let result = self.sk.contains(&enc_str, &enc_pat); let end = Instant::now(); - let dec = self.ck.key().decrypt_bool(&result); + let dec = self.ck.decrypt_bool(&result); println!("\n\x1b[1mContains:\x1b[0m"); result_message_pat(str, pat, expected, dec, end.duration_since(start)); @@ -161,7 +161,7 @@ impl Keys { let result = self.sk.contains(&enc_str, &clear_pat); let end = Instant::now(); - let dec = self.ck.key().decrypt_bool(&result); + let dec = self.ck.decrypt_bool(&result); println!("\n\x1b[1mContains:\x1b[0m"); result_message_clear_pat(str, pat, expected, dec, end.duration_since(start)); @@ -186,7 +186,7 @@ impl Keys { let result = self.sk.ends_with(&enc_str, &enc_pat); let end = Instant::now(); - let dec = self.ck.key().decrypt_bool(&result); + let dec = self.ck.decrypt_bool(&result); println!("\n\x1b[1mEnds_with:\x1b[0m"); result_message_pat(str, pat, expected, dec, end.duration_since(start)); @@ -197,7 +197,7 @@ impl Keys { let result = self.sk.ends_with(&enc_str, &clear_pat); let end = Instant::now(); - let dec = self.ck.key().decrypt_bool(&result); + let dec = self.ck.decrypt_bool(&result); println!("\n\x1b[1mEnds_with:\x1b[0m"); result_message_clear_pat(str, pat, expected, dec, end.duration_since(start)); @@ -222,7 +222,7 @@ impl Keys { let result = self.sk.starts_with(&enc_str, &enc_pat); let end = Instant::now(); - let dec = self.ck.key().decrypt_bool(&result); + let dec = self.ck.decrypt_bool(&result); println!("\n\x1b[1mStarts_with:\x1b[0m"); result_message_pat(str, pat, expected, dec, end.duration_since(start)); @@ -233,7 +233,7 @@ impl Keys { let result = self.sk.starts_with(&enc_str, &clear_pat); let end = Instant::now(); - let dec = self.ck.key().decrypt_bool(&result); + let dec = self.ck.decrypt_bool(&result); println!("\n\x1b[1mStarts_with:\x1b[0m"); result_message_clear_pat(str, pat, expected, dec, end.duration_since(start)); @@ -252,8 +252,8 @@ impl Keys { let (index, is_some) = self.sk.find(&enc_str, &enc_pat); let end = Instant::now(); - let dec_index = self.ck.key().decrypt_radix::(&index); - let dec_is_some = self.ck.key().decrypt_bool(&is_some); + let dec_index = self.ck.decrypt_radix::(&index); + let dec_is_some = self.ck.decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_index as usize); @@ -266,8 +266,8 @@ impl Keys { let (index, is_some) = self.sk.find(&enc_str, &clear_pat); let end = Instant::now(); - let dec_index = self.ck.key().decrypt_radix::(&index); - let dec_is_some = self.ck.key().decrypt_bool(&is_some); + let dec_index = self.ck.decrypt_radix::(&index); + let dec_is_some = self.ck.decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_index as usize); @@ -288,8 +288,8 @@ impl Keys { let (index, is_some) = self.sk.rfind(&enc_str, &enc_pat); let end = Instant::now(); - let dec_index = self.ck.key().decrypt_radix::(&index); - let dec_is_some = self.ck.key().decrypt_bool(&is_some); + let dec_index = self.ck.decrypt_radix::(&index); + let dec_is_some = self.ck.decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_index as usize); @@ -302,8 +302,8 @@ impl Keys { let (index, is_some) = self.sk.rfind(&enc_str, &clear_pat); let end = Instant::now(); - let dec_index = self.ck.key().decrypt_radix::(&index); - let dec_is_some = self.ck.key().decrypt_bool(&is_some); + let dec_index = self.ck.decrypt_radix::(&index); + let dec_is_some = self.ck.decrypt_bool(&is_some); let dec = dec_is_some.then_some(dec_index as usize); @@ -331,7 +331,7 @@ impl Keys { let end = Instant::now(); let dec_result = self.ck.decrypt_ascii(&result); - let dec_is_some = self.ck.key().decrypt_bool(&is_some); + let dec_is_some = self.ck.decrypt_bool(&is_some); if !dec_is_some { // When it's None, the FheString returned is the original str assert_eq!(dec_result, str); @@ -349,7 +349,7 @@ impl Keys { let end = Instant::now(); let dec_result = self.ck.decrypt_ascii(&result); - let dec_is_some = self.ck.key().decrypt_bool(&is_some); + let dec_is_some = self.ck.decrypt_bool(&is_some); if !dec_is_some { // When it's None, the FheString returned is the original str assert_eq!(dec_result, str); @@ -381,7 +381,7 @@ impl Keys { let end = Instant::now(); let dec_result = self.ck.decrypt_ascii(&result); - let dec_is_some = self.ck.key().decrypt_bool(&is_some); + let dec_is_some = self.ck.decrypt_bool(&is_some); if !dec_is_some { // When it's None, the FheString returned is the original str assert_eq!(dec_result, str); @@ -399,7 +399,7 @@ impl Keys { let end = Instant::now(); let dec_result = self.ck.decrypt_ascii(&result); - let dec_is_some = self.ck.key().decrypt_bool(&is_some); + let dec_is_some = self.ck.decrypt_bool(&is_some); if !dec_is_some { // When it's None, the FheString returned is the original str assert_eq!(dec_result, str); @@ -430,7 +430,7 @@ impl Keys { let result = self.sk.eq_ignore_case(&enc_lhs, &enc_rhs); let end = Instant::now(); - let dec = self.ck.key().decrypt_bool(&result); + let dec = self.ck.decrypt_bool(&result); println!("\n\x1b[1mEq_ignore_case:\x1b[0m"); result_message_rhs(str, rhs, expected, dec, end.duration_since(start)); @@ -441,7 +441,7 @@ impl Keys { let result = self.sk.eq_ignore_case(&enc_lhs, &clear_rhs); let end = Instant::now(); - let dec = self.ck.key().decrypt_bool(&result); + let dec = self.ck.decrypt_bool(&result); println!("\n\x1b[1mEq_ignore_case:\x1b[0m"); result_message_clear_rhs(str, rhs, expected, dec, end.duration_since(start)); @@ -461,7 +461,7 @@ impl Keys { let result_eq = self.sk.eq(&enc_lhs, &enc_rhs); let end = Instant::now(); - let dec_eq = self.ck.key().decrypt_bool(&result_eq); + let dec_eq = self.ck.decrypt_bool(&result_eq); println!("\n\x1b[1mEq:\x1b[0m"); result_message_rhs(str, rhs, expected_eq, dec_eq, end.duration_since(start)); @@ -472,7 +472,7 @@ impl Keys { let result_eq = self.sk.eq(&enc_lhs, &clear_rhs); let end = Instant::now(); - let dec_eq = self.ck.key().decrypt_bool(&result_eq); + let dec_eq = self.ck.decrypt_bool(&result_eq); println!("\n\x1b[1mEq:\x1b[0m"); result_message_clear_rhs(str, rhs, expected_eq, dec_eq, end.duration_since(start)); @@ -485,7 +485,7 @@ impl Keys { let result_ne = self.sk.ne(&enc_lhs, &enc_rhs); let end = Instant::now(); - let dec_ne = self.ck.key().decrypt_bool(&result_ne); + let dec_ne = self.ck.decrypt_bool(&result_ne); println!("\n\x1b[1mNe:\x1b[0m"); result_message_rhs(str, rhs, expected_ne, dec_ne, end.duration_since(start)); @@ -496,7 +496,7 @@ impl Keys { let result_ne = self.sk.ne(&enc_lhs, &clear_rhs); let end = Instant::now(); - let dec_ne = self.ck.key().decrypt_bool(&result_ne); + let dec_ne = self.ck.decrypt_bool(&result_ne); println!("\n\x1b[1mNe:\x1b[0m"); result_message_clear_rhs(str, rhs, expected_ne, dec_ne, end.duration_since(start)); @@ -511,7 +511,7 @@ impl Keys { let result_ge = self.sk.ge(&enc_lhs, &enc_rhs); let end = Instant::now(); - let dec_ge = self.ck.key().decrypt_bool(&result_ge); + let dec_ge = self.ck.decrypt_bool(&result_ge); println!("\n\x1b[1mGe:\x1b[0m"); result_message_rhs(str, rhs, expected_ge, dec_ge, end.duration_since(start)); @@ -524,7 +524,7 @@ impl Keys { let result_le = self.sk.le(&enc_lhs, &enc_rhs); let end = Instant::now(); - let dec_le = self.ck.key().decrypt_bool(&result_le); + let dec_le = self.ck.decrypt_bool(&result_le); println!("\n\x1b[1mLe:\x1b[0m"); result_message_rhs(str, rhs, expected_le, dec_le, end.duration_since(start)); @@ -537,7 +537,7 @@ impl Keys { let result_gt = self.sk.gt(&enc_lhs, &enc_rhs); let end = Instant::now(); - let dec_gt = self.ck.key().decrypt_bool(&result_gt); + let dec_gt = self.ck.decrypt_bool(&result_gt); println!("\n\x1b[1mGt:\x1b[0m"); result_message_rhs(str, rhs, expected_gt, dec_gt, end.duration_since(start)); @@ -550,7 +550,7 @@ impl Keys { let result_lt = self.sk.lt(&enc_lhs, &enc_rhs); let end = Instant::now(); - let dec_lt = self.ck.key().decrypt_bool(&result_lt); + let dec_lt = self.ck.decrypt_bool(&result_lt); println!("\n\x1b[1mLt:\x1b[0m"); result_message_rhs(str, rhs, expected_lt, dec_lt, end.duration_since(start)); @@ -736,7 +736,7 @@ impl Keys { let dec: Vec<_> = results .iter() .map(|(result, is_some)| { - let dec_is_some = self.ck.key().decrypt_bool(is_some); + let dec_is_some = self.ck.decrypt_bool(is_some); let dec_result = self.ck.decrypt_ascii(result); if !dec_is_some { // When it's None, the FheString returned is always empty @@ -776,7 +776,7 @@ impl Keys { let dec_lhs = self.ck.decrypt_ascii(&lhs); let dec_rhs = self.ck.decrypt_ascii(&rhs); - let dec_is_some = self.ck.key().decrypt_bool(&is_some); + let dec_is_some = self.ck.decrypt_bool(&is_some); let dec = dec_is_some.then_some((dec_lhs.as_str(), dec_rhs.as_str())); @@ -804,7 +804,7 @@ impl Keys { let dec_lhs = self.ck.decrypt_ascii(&lhs); let dec_rhs = self.ck.decrypt_ascii(&rhs); - let dec_is_some = self.ck.key().decrypt_bool(&is_some); + let dec_is_some = self.ck.decrypt_bool(&is_some); let dec = dec_is_some.then_some((dec_lhs.as_str(), dec_rhs.as_str())); @@ -835,7 +835,7 @@ impl Keys { let dec: Vec<_> = results .iter() .map(|(result, is_some)| { - let dec_is_some = self.ck.key().decrypt_bool(is_some); + let dec_is_some = self.ck.decrypt_bool(is_some); dec_is_some.then_some(self.ck.decrypt_ascii(result)) }) @@ -873,7 +873,7 @@ impl Keys { let dec: Vec<_> = results .iter() .map(|(result, is_some)| { - let dec_is_some = self.ck.key().decrypt_bool(is_some); + let dec_is_some = self.ck.decrypt_bool(is_some); dec_is_some.then_some(self.ck.decrypt_ascii(result)) }) @@ -917,7 +917,7 @@ impl Keys { let dec: Vec<_> = results .iter() .map(|(result, is_some)| { - let dec_is_some = self.ck.key().decrypt_bool(is_some); + let dec_is_some = self.ck.decrypt_bool(is_some); dec_is_some.then_some(self.ck.decrypt_ascii(result)) }) @@ -961,7 +961,7 @@ impl Keys { let dec: Vec<_> = results .iter() .map(|(result, is_some)| { - let dec_is_some = self.ck.key().decrypt_bool(is_some); + let dec_is_some = self.ck.decrypt_bool(is_some); dec_is_some.then_some(self.ck.decrypt_ascii(result)) }) @@ -1005,7 +1005,7 @@ impl Keys { let dec: Vec<_> = results .iter() .map(|(result, is_some)| { - let dec_is_some = self.ck.key().decrypt_bool(is_some); + let dec_is_some = self.ck.decrypt_bool(is_some); dec_is_some.then_some(self.ck.decrypt_ascii(result)) }) @@ -1051,7 +1051,7 @@ impl Keys { let dec: Vec<_> = results .iter() .map(|(result, is_some)| { - let dec_is_some = self.ck.key().decrypt_bool(is_some); + let dec_is_some = self.ck.decrypt_bool(is_some); dec_is_some.then_some(self.ck.decrypt_ascii(result)) }) @@ -1097,7 +1097,7 @@ impl Keys { let dec: Vec<_> = results .iter() .map(|(result, is_some)| { - let dec_is_some = self.ck.key().decrypt_bool(is_some); + let dec_is_some = self.ck.decrypt_bool(is_some); dec_is_some.then_some(self.ck.decrypt_ascii(result)) }) @@ -1158,7 +1158,7 @@ impl Keys { let dec: Vec<_> = results .iter() .map(|(result, is_some)| { - let dec_is_some = self.ck.key().decrypt_bool(is_some); + let dec_is_some = self.ck.decrypt_bool(is_some); dec_is_some.then_some(self.ck.decrypt_ascii(result)) }) @@ -1204,7 +1204,7 @@ impl Keys { let dec: Vec<_> = results .iter() .map(|(result, is_some)| { - let dec_is_some = self.ck.key().decrypt_bool(is_some); + let dec_is_some = self.ck.decrypt_bool(is_some); dec_is_some.then_some(self.ck.decrypt_ascii(result)) }) diff --git a/tfhe/src/strings/ciphertext.rs b/tfhe/src/strings/ciphertext.rs index 2331614aec..9bdd35cc94 100644 --- a/tfhe/src/strings/ciphertext.rs +++ b/tfhe/src/strings/ciphertext.rs @@ -1,7 +1,8 @@ -use crate::strings::client_key::{ClientKey, EncU16}; -use crate::strings::server_key::ServerKey; +use crate::integer::{ + ClientKey, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, ServerKey, +}; +use crate::strings::client_key::EncU16; use crate::strings::N; -use crate::integer::{IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext}; /// Represents a encrypted ASCII character. #[derive(Clone)] @@ -57,7 +58,7 @@ impl FheAsciiChar { pub fn null(sk: &ServerKey) -> Self { Self { - enc_char: sk.key().create_trivial_zero_radix(4), + enc_char: sk.create_trivial_zero_radix(4), } } } @@ -81,7 +82,7 @@ impl FheString { let enc_string: Vec<_> = str .bytes() .map(|char| FheAsciiChar { - enc_char: server_key.key().create_trivial_radix(char, 4), + enc_char: server_key.create_trivial_radix(char, 4), }) .collect(); @@ -153,8 +154,7 @@ impl FheString { let mut uint = RadixCiphertext::from_blocks(blocks); if uint.blocks().is_empty() { - sk.key() - .extend_radix_with_trivial_zero_blocks_lsb_assign(&mut uint, 4); + sk.extend_radix_with_trivial_zero_blocks_lsb_assign(&mut uint, 4); } uint @@ -189,11 +189,13 @@ impl FheString { #[cfg(test)] mod tests { use super::*; - use crate::strings::server_key::gen_keys; + use crate::integer::{ClientKey, ServerKey}; + use crate::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; #[test] fn test_uint_conversion() { - let (ck, sk) = gen_keys(); + let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + let sk = ServerKey::new_radix_server_key(&ck); let str = "Los Sheikah fueron originalmente criados de la Diosa Hylia antes del sellado del \ diff --git a/tfhe/src/strings/client_key.rs b/tfhe/src/strings/client_key.rs index dc794ddc86..5a056fcac1 100644 --- a/tfhe/src/strings/client_key.rs +++ b/tfhe/src/strings/client_key.rs @@ -1,14 +1,6 @@ -use crate::integer::{ClientKey as FheClientKey, RadixCiphertext}; -use crate::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; +use crate::integer::{ClientKey, RadixCiphertext}; use crate::strings::ciphertext::{FheAsciiChar, FheString}; -/// Represents a client key for encryption and decryption of strings. -#[derive(serde::Serialize, serde::Deserialize, Clone)] -pub struct ClientKey { - key: FheClientKey, -} - -/// Encrypted u16 value. It contains an optional `max` to restrict the range of the value. pub struct EncU16 { cipher: RadixCiphertext, max: Option, @@ -25,17 +17,6 @@ impl EncU16 { } impl ClientKey { - #[allow(clippy::new_without_default)] - pub fn new() -> Self { - Self { - key: FheClientKey::new(PARAM_MESSAGE_2_CARRY_2), - } - } - - pub fn key(&self) -> &FheClientKey { - &self.key - } - /// Encrypts an ASCII string, optionally padding it with the specified amount of 0s, and returns /// an [`FheString`]. /// @@ -51,14 +32,14 @@ impl ClientKey { let mut enc_string: Vec<_> = str .bytes() .map(|char| FheAsciiChar { - enc_char: self.key.encrypt_radix(char, 4), + enc_char: self.encrypt_radix(char, 4), }) .collect(); // Optional padding if let Some(count) = padding { let null = (0..count).map(|_| FheAsciiChar { - enc_char: self.key.encrypt_radix(0u8, 4), + enc_char: self.encrypt_radix(0u8, 4), }); enc_string.extend(null); @@ -81,7 +62,7 @@ impl ClientKey { .chars() .iter() .filter_map(|enc_char| { - let byte = self.key.decrypt_radix(enc_char.ciphertext()); + let byte = self.decrypt_radix(enc_char.ciphertext()); if byte == 0 { prev_was_null = true; @@ -123,7 +104,7 @@ impl ClientKey { } EncU16 { - cipher: self.key.encrypt_radix(val, 8), + cipher: self.encrypt_radix(val, 8), max, } } diff --git a/tfhe/src/strings/mod.rs b/tfhe/src/strings/mod.rs index 6e457a42a7..117e213598 100644 --- a/tfhe/src/strings/mod.rs +++ b/tfhe/src/strings/mod.rs @@ -1,9 +1,7 @@ +use crate::integer::{ClientKey, ServerKey}; +use crate::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; use crate::strings::ciphertext::{FheString, UIntArg}; -use crate::strings::client_key::ClientKey; -use crate::strings::server_key::{ - gen_keys, FheStringIsEmpty, FheStringIterator, FheStringLen, ServerKey, -}; -// use clap::{value_parser, Arg, Command}; +use crate::strings::server_key::{FheStringIsEmpty, FheStringIterator, FheStringLen}; use std::time::Instant; pub mod ciphertext; @@ -89,7 +87,8 @@ struct Keys { impl Keys { fn new() -> Self { - let (ck, sk) = gen_keys(); + let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + let sk = ServerKey::new_radix_server_key(&ck); Self { ck, sk } } diff --git a/tfhe/src/strings/server_key/comp.rs b/tfhe/src/strings/server_key/comp.rs index f71dc0c268..64f9abbcfb 100644 --- a/tfhe/src/strings/server_key/comp.rs +++ b/tfhe/src/strings/server_key/comp.rs @@ -12,9 +12,7 @@ impl ServerKey { if lhs_len == 0 || (lhs.is_padded() && lhs_len == 1) { return match self.is_empty(rhs) { FheStringIsEmpty::Padding(enc_val) => Some(enc_val), - FheStringIsEmpty::NoPadding(val) => { - Some(self.key.create_trivial_boolean_block(val)) - } + FheStringIsEmpty::NoPadding(val) => Some(self.create_trivial_boolean_block(val)), }; } @@ -23,15 +21,13 @@ impl ServerKey { if rhs_len == 0 || (rhs.is_padded() && rhs_len == 1) { return match self.is_empty(lhs) { FheStringIsEmpty::Padding(enc_val) => Some(enc_val), - FheStringIsEmpty::NoPadding(_) => { - Some(self.key.create_trivial_boolean_block(false)) - } + FheStringIsEmpty::NoPadding(_) => Some(self.create_trivial_boolean_block(false)), }; } // Two strings without padding that have different lengths cannot be equal if (!lhs.is_padded() && !rhs.is_padded()) && (lhs.len() != rhs.len()) { - return Some(self.key.create_trivial_boolean_block(false)); + return Some(self.create_trivial_boolean_block(false)); } // A string without padding cannot be equal to a string with padding that has the same or @@ -39,7 +35,7 @@ impl ServerKey { if (!lhs.is_padded() && rhs.is_padded()) && (rhs.len() <= lhs.len()) || (!rhs.is_padded() && lhs.is_padded()) && (lhs.len() <= rhs.len()) { - return Some(self.key.create_trivial_boolean_block(false)); + return Some(self.create_trivial_boolean_block(false)); } None @@ -55,17 +51,19 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s1, s2) = ("hello", "hello"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// /// let result = sk.eq(&enc_s1, &enc_s2); - /// let are_equal = ck.key().decrypt_bool(&result); + /// let are_equal = ck.decrypt_bool(&result); /// /// assert!(are_equal); /// ``` @@ -86,14 +84,14 @@ impl ServerKey { GenericPattern::Clear(rhs) => { let rhs_clear_uint = self.pad_cipher_and_cleartext_lsb(&mut lhs_uint, rhs.str()); - self.key.scalar_eq_parallelized(&lhs_uint, rhs_clear_uint) + self.scalar_eq_parallelized(&lhs_uint, rhs_clear_uint) } GenericPattern::Enc(rhs) => { let mut rhs_uint = rhs.to_uint(self); self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); - self.key.eq_parallelized(&lhs_uint, &rhs_uint) + self.eq_parallelized(&lhs_uint, &rhs_uint) } } } @@ -109,24 +107,26 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s1, s2) = ("hello", "world"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// /// let result = sk.ne(&enc_s1, &enc_s2); - /// let are_not_equal = ck.key().decrypt_bool(&result); + /// let are_not_equal = ck.decrypt_bool(&result); /// /// assert!(are_not_equal); /// ``` pub fn ne(&self, lhs: &FheString, rhs: &GenericPattern) -> BooleanBlock { let eq = self.eq(lhs, rhs); - self.key.boolean_bitnot(&eq) + self.boolean_bitnot(&eq) } /// Returns `true` if the first encrypted string is less than the second encrypted string. @@ -136,17 +136,19 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s1, s2) = ("apple", "banana"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = FheString::new(&ck, s2, None); /// /// let result = sk.lt(&enc_s1, &enc_s2); - /// let is_lt = ck.key().decrypt_bool(&result); + /// let is_lt = ck.decrypt_bool(&result); /// /// assert!(is_lt); // "apple" is less than "banana" /// ``` @@ -156,7 +158,7 @@ impl ServerKey { self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); - self.key.lt_parallelized(&lhs_uint, &rhs_uint) + self.lt_parallelized(&lhs_uint, &rhs_uint) } /// Returns `true` if the first encrypted string is greater than the second encrypted string. @@ -166,17 +168,19 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s1, s2) = ("banana", "apple"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = FheString::new(&ck, s2, None); /// /// let result = sk.gt(&enc_s1, &enc_s2); - /// let is_gt = ck.key().decrypt_bool(&result); + /// let is_gt = ck.decrypt_bool(&result); /// /// assert!(is_gt); // "banana" is greater than "apple" /// ``` @@ -186,7 +190,7 @@ impl ServerKey { self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); - self.key.gt_parallelized(&lhs_uint, &rhs_uint) + self.gt_parallelized(&lhs_uint, &rhs_uint) } /// Returns `true` if the first encrypted string is less than or equal to the second encrypted @@ -197,17 +201,19 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s1, s2) = ("apple", "banana"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = FheString::new(&ck, s2, None); /// /// let result = sk.le(&enc_s1, &enc_s2); - /// let is_le = ck.key().decrypt_bool(&result); + /// let is_le = ck.decrypt_bool(&result); /// /// assert!(is_le); // "apple" is less than or equal to "banana" /// ``` @@ -217,7 +223,7 @@ impl ServerKey { self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); - self.key.le_parallelized(&lhs_uint, &rhs_uint) + self.le_parallelized(&lhs_uint, &rhs_uint) } /// Returns `true` if the first encrypted string is greater than or equal to the second @@ -228,17 +234,19 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s1, s2) = ("banana", "apple"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = FheString::new(&ck, s2, None); /// /// let result = sk.ge(&enc_s1, &enc_s2); - /// let is_ge = ck.key().decrypt_bool(&result); + /// let is_ge = ck.decrypt_bool(&result); /// /// assert!(is_ge); // "banana" is greater than or equal to "apple" /// ``` @@ -248,6 +256,6 @@ impl ServerKey { self.pad_ciphertexts_lsb(&mut lhs_uint, &mut rhs_uint); - self.key.ge_parallelized(&lhs_uint, &rhs_uint) + self.ge_parallelized(&lhs_uint, &rhs_uint) } } diff --git a/tfhe/src/strings/server_key/mod.rs b/tfhe/src/strings/server_key/mod.rs index 7adf103916..3d89f3cdfe 100644 --- a/tfhe/src/strings/server_key/mod.rs +++ b/tfhe/src/strings/server_key/mod.rs @@ -7,38 +7,12 @@ pub use trim::split_ascii_whitespace; use crate::integer::bigint::static_unsigned::StaticUnsignedBigInt; use crate::integer::prelude::*; -use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey as FheServerKey}; +use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey}; use crate::strings::ciphertext::{FheAsciiChar, FheString}; -use crate::strings::client_key::ClientKey; use crate::strings::N; use rayon::prelude::*; use std::cmp::Ordering; -/// Represents a server key to operate homomorphically on [`FheString`]. -#[derive(serde::Serialize, serde::Deserialize, Clone)] -pub struct ServerKey { - key: FheServerKey, -} - -pub fn gen_keys() -> (ClientKey, ServerKey) { - let ck = ClientKey::new(); - let sk = ServerKey::new(&ck); - - (ck, sk) -} - -impl ServerKey { - pub fn new(from: &ClientKey) -> Self { - Self { - key: FheServerKey::new_radix_server_key(from.key()), - } - } - - pub fn key(&self) -> &FheServerKey { - &self.key - } -} - // With no padding, the length is just the vector's length (clear result). With padding it requires // homomorphically counting the non zero elements (encrypted result). pub enum FheStringLen { @@ -77,7 +51,7 @@ impl ServerKey { self.trim_ciphertexts_lsb(&mut uint_str, &mut uint_pat); - self.key.eq_parallelized(&uint_str, &uint_pat) + self.eq_parallelized(&uint_str, &uint_pat) } fn clear_asciis_eq<'a, I>(&self, str: I, pat: &str) -> BooleanBlock @@ -104,41 +78,38 @@ impl ServerKey { } Ordering::Greater => { let diff = str_block_len - pat_block_len; - self.key.trim_radix_blocks_lsb_assign(&mut uint_str, diff); + self.trim_radix_blocks_lsb_assign(&mut uint_str, diff); } Ordering::Equal => (), } let clear_pat_uint = self.pad_cipher_and_cleartext_lsb(&mut uint_str, clear_pat); - self.key.scalar_eq_parallelized(&uint_str, clear_pat_uint) + self.scalar_eq_parallelized(&uint_str, clear_pat_uint) } fn asciis_eq_ignore_pat_pad<'a, I>(&self, str_pat: I) -> BooleanBlock where I: ParallelIterator, { - let mut result = self.key.create_trivial_boolean_block(true); + let mut result = self.create_trivial_boolean_block(true); let eq_or_null_pat: Vec<_> = str_pat .map(|(str_char, pat_char)| { let (are_eq, pat_is_null) = rayon::join( - || { - self.key - .eq_parallelized(str_char.ciphertext(), pat_char.ciphertext()) - }, - || self.key.scalar_eq_parallelized(pat_char.ciphertext(), 0u8), + || self.eq_parallelized(str_char.ciphertext(), pat_char.ciphertext()), + || self.scalar_eq_parallelized(pat_char.ciphertext(), 0u8), ); // If `pat_char` is null then `are_eq` is set to true. Hence if ALL `pat_char`s are // null, the result is always true, which is correct since the pattern is empty - self.key.boolean_bitor(&are_eq, &pat_is_null) + self.boolean_bitor(&are_eq, &pat_is_null) }) .collect(); for eq_or_null in eq_or_null_pat { // Will be false if `str_char` != `pat_char` and `pat_char` isn't null - self.key.boolean_bitand_assign(&mut result, &eq_or_null); + self.boolean_bitand_assign(&mut result, &eq_or_null); } result @@ -161,8 +132,7 @@ impl ServerKey { // Also fill the lhs with null blocks at the end if lhs.blocks().len() < N * 8 * 4 { let diff = N * 8 * 4 - lhs.blocks().len(); - self.key - .extend_radix_with_trivial_zero_blocks_lsb_assign(lhs, diff); + self.extend_radix_with_trivial_zero_blocks_lsb_assign(lhs, diff); } rhs_clear_uint @@ -175,13 +145,11 @@ impl ServerKey { match lhs_blocks.cmp(&rhs_blocks) { Ordering::Less => { let diff = rhs_blocks - lhs_blocks; - self.key - .extend_radix_with_trivial_zero_blocks_lsb_assign(lhs, diff); + self.extend_radix_with_trivial_zero_blocks_lsb_assign(lhs, diff); } Ordering::Greater => { let diff = lhs_blocks - rhs_blocks; - self.key - .extend_radix_with_trivial_zero_blocks_lsb_assign(rhs, diff); + self.extend_radix_with_trivial_zero_blocks_lsb_assign(rhs, diff); } Ordering::Equal => (), } @@ -193,12 +161,11 @@ impl ServerKey { match cipher_len.cmp(&len) { Ordering::Less => { let diff = len - cipher_len; - self.key - .extend_radix_with_trivial_zero_blocks_msb_assign(cipher, diff); + self.extend_radix_with_trivial_zero_blocks_msb_assign(cipher, diff); } Ordering::Greater => { let diff = cipher_len - len; - self.key.trim_radix_blocks_msb_assign(cipher, diff); + self.trim_radix_blocks_msb_assign(cipher, diff); } Ordering::Equal => (), } @@ -211,11 +178,11 @@ impl ServerKey { match lhs_blocks.cmp(&rhs_blocks) { Ordering::Less => { let diff = rhs_blocks - lhs_blocks; - self.key.trim_radix_blocks_lsb_assign(rhs, diff); + self.trim_radix_blocks_lsb_assign(rhs, diff); } Ordering::Greater => { let diff = lhs_blocks - rhs_blocks; - self.key.trim_radix_blocks_lsb_assign(lhs, diff); + self.trim_radix_blocks_lsb_assign(lhs, diff); } Ordering::Equal => (), } @@ -235,9 +202,7 @@ impl ServerKey { self.pad_ciphertexts_lsb(&mut true_ct_uint, &mut false_ct_uint); - let result_uint = - self.key - .if_then_else_parallelized(condition, &true_ct_uint, &false_ct_uint); + let result_uint = self.if_then_else_parallelized(condition, &true_ct_uint, &false_ct_uint); let mut result = FheString::from_uint(result_uint, false); if padded { @@ -253,21 +218,21 @@ impl ServerKey { fn left_shift_chars(&self, str: &FheString, shift: &RadixCiphertext) -> FheString { let uint = str.to_uint(self); - let mut shift_bits = self.key.scalar_left_shift_parallelized(shift, 3); + let mut shift_bits = self.scalar_left_shift_parallelized(shift, 3); // `shift_bits` needs to have the same block len as `uint` for the tfhe-rs shift to work self.pad_or_trim_ciphertext(&mut shift_bits, uint.blocks().len()); - let shifted = self.key.left_shift_parallelized(&uint, &shift_bits); + let shifted = self.left_shift_parallelized(&uint, &shift_bits); // If the shifting amount is >= than the str length we get zero i.e. all chars are out of // range (instead of wrapping, which is the behavior of Rust and tfhe-rs) let bit_len = (str.len() * 8) as u32; - let shift_ge_than_str = self.key.scalar_ge_parallelized(&shift_bits, bit_len); + let shift_ge_than_str = self.scalar_ge_parallelized(&shift_bits, bit_len); - let result = self.key.if_then_else_parallelized( + let result = self.if_then_else_parallelized( &shift_ge_than_str, - &self.key.create_trivial_zero_radix(uint.blocks().len()), + &self.create_trivial_zero_radix(uint.blocks().len()), &shifted, ); @@ -276,21 +241,21 @@ impl ServerKey { fn right_shift_chars(&self, str: &FheString, shift: &RadixCiphertext) -> FheString { let uint = str.to_uint(self); - let mut shift_bits = self.key.scalar_left_shift_parallelized(shift, 3); + let mut shift_bits = self.scalar_left_shift_parallelized(shift, 3); // `shift_bits` needs to have the same block len as `uint` for the tfhe-rs shift to work self.pad_or_trim_ciphertext(&mut shift_bits, uint.blocks().len()); - let shifted = self.key.right_shift_parallelized(&uint, &shift_bits); + let shifted = self.right_shift_parallelized(&uint, &shift_bits); // If the shifting amount is >= than the str length we get zero i.e. all chars are out of // range (instead of wrapping, which is the behavior of Rust and tfhe-rs) let bit_len = (str.len() * 8) as u32; - let shift_ge_than_str = self.key.scalar_ge_parallelized(&shift_bits, bit_len); + let shift_ge_than_str = self.scalar_ge_parallelized(&shift_bits, bit_len); - let result = self.key.if_then_else_parallelized( + let result = self.if_then_else_parallelized( &shift_ge_than_str, - &self.key.create_trivial_zero_radix(uint.blocks().len()), + &self.create_trivial_zero_radix(uint.blocks().len()), &shifted, ); diff --git a/tfhe/src/strings/server_key/no_patterns.rs b/tfhe/src/strings/server_key/no_patterns.rs index b4ebd002a0..59f2c3f9fd 100644 --- a/tfhe/src/strings/server_key/no_patterns.rs +++ b/tfhe/src/strings/server_key/no_patterns.rs @@ -13,10 +13,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::{gen_keys, FheStringLen}; - /// - /// let (ck, sk) = gen_keys(); + /// use tfhe::strings::server_key::FheStringLen; + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let s = "hello"; /// let number_of_nulls = 3; /// @@ -35,7 +37,7 @@ impl ServerKey { /// FheStringLen::NoPadding(_) => panic!("Unexpected no padding"), /// FheStringLen::Padding(ciphertext) => { /// // Homomorphically computed length, requires decryption for actual length - /// let length = ck.key().decrypt_radix::(&ciphertext); + /// let length = ck.decrypt_radix::(&ciphertext); /// assert_eq!(length, 5) /// } /// } @@ -46,14 +48,13 @@ impl ServerKey { .chars() .par_iter() .map(|char| { - let bool = self.key.scalar_ne_parallelized(char.ciphertext(), 0u8); - bool.into_radix(16, &self.key) + let bool = self.scalar_ne_parallelized(char.ciphertext(), 0u8); + bool.into_radix(16, self) }) .collect(); // If we add the number of non-zero elements we get the actual length, without padding let len = self - .key .sum_ciphertexts_parallelized(non_zero_chars.iter()) .expect("There's at least one padding character"); @@ -72,10 +73,13 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::{gen_keys, FheStringIsEmpty}; + /// use tfhe::strings::server_key::FheStringIsEmpty; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let s = ""; /// let number_of_nulls = 2; /// @@ -94,7 +98,7 @@ impl ServerKey { /// FheStringIsEmpty::NoPadding(_) => panic!("Unexpected no padding"), /// FheStringIsEmpty::Padding(ciphertext) => { /// // Homomorphically computed emptiness, requires decryption for actual value - /// let is_empty = ck.key().decrypt_bool(&ciphertext); + /// let is_empty = ck.decrypt_bool(&ciphertext); /// assert!(is_empty) /// } /// } @@ -102,11 +106,11 @@ impl ServerKey { pub fn is_empty(&self, str: &FheString) -> FheStringIsEmpty { if str.is_padded() { if str.len() == 1 { - return FheStringIsEmpty::Padding(self.key.create_trivial_boolean_block(true)); + return FheStringIsEmpty::Padding(self.create_trivial_boolean_block(true)); } let str_uint = str.to_uint(self); - let result = self.key.scalar_eq_parallelized(&str_uint, 0u8); + let result = self.scalar_eq_parallelized(&str_uint, 0u8); FheStringIsEmpty::Padding(result) } else { @@ -119,10 +123,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let s = "Hello World"; /// /// let enc_s = FheString::new(&ck, s, None); @@ -141,11 +147,11 @@ impl ServerKey { .par_iter() .map(|char| { let (ge_97, le_122) = rayon::join( - || self.key.scalar_ge_parallelized(char.ciphertext(), 97u8), - || self.key.scalar_le_parallelized(char.ciphertext(), 122u8), + || self.scalar_ge_parallelized(char.ciphertext(), 97u8), + || self.scalar_le_parallelized(char.ciphertext(), 122u8), ); - self.key.boolean_bitand(&ge_97, &le_122) + self.boolean_bitand(&ge_97, &le_122) }) .collect(); @@ -155,13 +161,11 @@ impl ServerKey { .par_iter_mut() .zip(lowercase_chars.into_par_iter()) .for_each(|(char, is_lowercase)| { - let mut subtract = self.key.create_trivial_radix(32, 4); + let mut subtract = self.create_trivial_radix(32, 4); - self.key - .mul_assign_parallelized(&mut subtract, &is_lowercase.into_radix(1, &self.key)); + self.mul_assign_parallelized(&mut subtract, &is_lowercase.into_radix(1, self)); - self.key - .sub_assign_parallelized(char.ciphertext_mut(), &subtract); + self.sub_assign_parallelized(char.ciphertext_mut(), &subtract); }); uppercase @@ -172,10 +176,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let s = "Hello World"; /// /// let enc_s = FheString::new(&ck, s, None); @@ -194,11 +200,11 @@ impl ServerKey { .par_iter() .map(|char| { let (ge_65, le_90) = rayon::join( - || self.key.scalar_ge_parallelized(char.ciphertext(), 65u8), - || self.key.scalar_le_parallelized(char.ciphertext(), 90u8), + || self.scalar_ge_parallelized(char.ciphertext(), 65u8), + || self.scalar_le_parallelized(char.ciphertext(), 90u8), ); - self.key.boolean_bitand(&ge_65, &le_90) + self.boolean_bitand(&ge_65, &le_90) }) .collect(); @@ -208,13 +214,11 @@ impl ServerKey { .par_iter_mut() .zip(uppercase_chars) .for_each(|(char, is_uppercase)| { - let mut add = self.key.create_trivial_radix(32, 4); + let mut add = self.create_trivial_radix(32, 4); - self.key - .mul_assign_parallelized(&mut add, &is_uppercase.into_radix(1, &self.key)); + self.mul_assign_parallelized(&mut add, &is_uppercase.into_radix(1, self)); - self.key - .add_assign_parallelized(char.ciphertext_mut(), &add); + self.add_assign_parallelized(char.ciphertext_mut(), &add); }); lowercase @@ -231,17 +235,19 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s1, s2) = ("Hello", "hello"); /// /// let enc_s1 = FheString::new(&ck, s1, None); /// let enc_s2 = GenericPattern::Enc(FheString::new(&ck, s2, None)); /// /// let result = sk.eq_ignore_case(&enc_s1, &enc_s2); - /// let are_equal = ck.key().decrypt_bool(&result); + /// let are_equal = ck.decrypt_bool(&result); /// /// assert!(are_equal); /// ``` @@ -266,10 +272,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (lhs, rhs) = ("Hello, ", "world!"); /// /// let enc_lhs = FheString::new(&ck, lhs, None); @@ -293,8 +301,8 @@ impl ServerKey { // If lhs is padded we can shift it right such that all nulls move to the start, then // we append the rhs and shift it left again to move the nulls to the new end FheStringLen::Padding(len) => { - let padded_len = self.key.create_trivial_radix(lhs.len() as u32, 16); - let number_of_nulls = self.key.sub_parallelized(&padded_len, &len); + let padded_len = self.create_trivial_radix(lhs.len() as u32, 16); + let number_of_nulls = self.sub_parallelized(&padded_len, &len); result = self.right_shift_chars(&result, &number_of_nulls); @@ -317,10 +325,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, UIntArg}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let s = "hi"; /// /// let enc_s = FheString::new(&ck, s, None); @@ -361,11 +371,11 @@ impl ServerKey { } } UIntArg::Enc(enc_n) => { - let n_is_zero = self.key.scalar_eq_parallelized(enc_n.cipher(), 0); + let n_is_zero = self.scalar_eq_parallelized(enc_n.cipher(), 0); result = self.conditional_string(&n_is_zero, FheString::empty(), &result); for i in 0..enc_n.max().unwrap_or(u16::MAX) - 1 { - let n_is_exceeded = self.key.scalar_le_parallelized(enc_n.cipher(), i + 1); + let n_is_exceeded = self.scalar_le_parallelized(enc_n.cipher(), i + 1); let append = self.conditional_string(&n_is_exceeded, FheString::empty(), str); result = self.concat(&result, &append); diff --git a/tfhe/src/strings/server_key/pattern/contains.rs b/tfhe/src/strings/server_key/pattern/contains.rs index 6ec4856276..a04a288ae7 100644 --- a/tfhe/src/strings/server_key/pattern/contains.rs +++ b/tfhe/src/strings/server_key/pattern/contains.rs @@ -36,7 +36,7 @@ impl ServerKey { let block_vec: Vec<_> = matched .into_iter() .map(|bool| { - let radix: RadixCiphertext = bool.into_radix(1, &self.key); + let radix: RadixCiphertext = bool.into_radix(1, self); radix.into_blocks()[0].clone() }) .collect(); @@ -44,7 +44,7 @@ impl ServerKey { // This will be 0 if there was no match, non-zero otherwise let combined_radix = RadixCiphertext::from(block_vec); - self.key.scalar_ne_parallelized(&combined_radix, 0) + self.scalar_ne_parallelized(&combined_radix, 0) } fn clear_compare_shifted( @@ -61,7 +61,7 @@ impl ServerKey { let block_vec: Vec<_> = matched .into_iter() .map(|bool| { - let radix: RadixCiphertext = bool.into_radix(1, &self.key); + let radix: RadixCiphertext = bool.into_radix(1, self); radix.into_blocks()[0].clone() }) .collect(); @@ -69,7 +69,7 @@ impl ServerKey { // This will be 0 if there was no match, non-zero otherwise let combined_radix = RadixCiphertext::from(block_vec); - self.key.scalar_ne_parallelized(&combined_radix, 0) + self.scalar_ne_parallelized(&combined_radix, 0) } /// Returns `true` if the given pattern (either encrypted or clear) matches a substring of this @@ -83,10 +83,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{ClearString, FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (bananas, nana, apples) = ("bananas", "nana", "apples"); /// /// let enc_bananas = FheString::new(&ck, bananas, None); @@ -96,8 +98,8 @@ impl ServerKey { /// let result1 = sk.contains(&enc_bananas, &enc_nana); /// let result2 = sk.contains(&enc_bananas, &clear_apples); /// - /// let should_be_true = ck.key().decrypt_bool(&result1); - /// let should_be_false = ck.key().decrypt_bool(&result2); + /// let should_be_true = ck.decrypt_bool(&result1); + /// let should_be_false = ck.decrypt_bool(&result2); /// /// assert!(should_be_true); /// assert!(!should_be_false); @@ -109,7 +111,7 @@ impl ServerKey { }; match self.length_checks(str, &trivial_or_enc_pat) { - IsMatch::Clear(val) => return self.key.create_trivial_boolean_block(val), + IsMatch::Clear(val) => return self.create_trivial_boolean_block(val), IsMatch::Cipher(val) => return val, IsMatch::None => (), } @@ -142,10 +144,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{ClearString, FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (bananas, ba, nan) = ("bananas", "ba", "nan"); /// /// let enc_bananas = FheString::new(&ck, bananas, None); @@ -155,8 +159,8 @@ impl ServerKey { /// let result1 = sk.starts_with(&enc_bananas, &enc_ba); /// let result2 = sk.starts_with(&enc_bananas, &clear_nan); /// - /// let should_be_true = ck.key().decrypt_bool(&result1); - /// let should_be_false = ck.key().decrypt_bool(&result2); + /// let should_be_true = ck.decrypt_bool(&result1); + /// let should_be_false = ck.decrypt_bool(&result2); /// /// assert!(should_be_true); /// assert!(!should_be_false); @@ -168,7 +172,7 @@ impl ServerKey { }; match self.length_checks(str, &trivial_or_enc_pat) { - IsMatch::Clear(val) => return self.key.create_trivial_boolean_block(val), + IsMatch::Clear(val) => return self.create_trivial_boolean_block(val), IsMatch::Cipher(val) => return val, IsMatch::None => (), } @@ -214,10 +218,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{ClearString, FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (bananas, anas, nana) = ("bananas", "anas", "nana"); /// /// let enc_bananas = FheString::new(&ck, bananas, None); @@ -227,8 +233,8 @@ impl ServerKey { /// let result1 = sk.ends_with(&enc_bananas, &enc_anas); /// let result2 = sk.ends_with(&enc_bananas, &clear_nana); /// - /// let should_be_true = ck.key().decrypt_bool(&result1); - /// let should_be_false = ck.key().decrypt_bool(&result2); + /// let should_be_true = ck.decrypt_bool(&result1); + /// let should_be_false = ck.decrypt_bool(&result2); /// /// assert!(should_be_true); /// assert!(!should_be_false); @@ -240,7 +246,7 @@ impl ServerKey { }; match self.length_checks(str, &trivial_or_enc_pat) { - IsMatch::Clear(val) => return self.key.create_trivial_boolean_block(val), + IsMatch::Clear(val) => return self.create_trivial_boolean_block(val), IsMatch::Cipher(val) => return val, IsMatch::None => (), } diff --git a/tfhe/src/strings/server_key/pattern/find.rs b/tfhe/src/strings/server_key/pattern/find.rs index 2d54be2ee0..a807f5c301 100644 --- a/tfhe/src/strings/server_key/pattern/find.rs +++ b/tfhe/src/strings/server_key/pattern/find.rs @@ -17,8 +17,8 @@ impl ServerKey { par_iter: IntoIter, ignore_pat_pad: bool, ) -> (RadixCiphertext, BooleanBlock) { - let mut result = self.key.create_trivial_boolean_block(false); - let mut last_match_index = self.key.create_trivial_zero_radix(16); + let mut result = self.create_trivial_boolean_block(false); + let mut last_match_index = self.create_trivial_zero_radix(16); let (str, pat) = str_pat; let matched: Vec<_> = par_iter @@ -40,16 +40,15 @@ impl ServerKey { .collect(); for (i, is_matched) in matched { - let index = self.key.create_trivial_radix(i as u32, 16); + let index = self.create_trivial_radix(i as u32, 16); rayon::join( || { last_match_index = - self.key - .if_then_else_parallelized(&is_matched, &index, &last_match_index) + self.if_then_else_parallelized(&is_matched, &index, &last_match_index) }, // One of the possible values of the padded pat must match the str - || self.key.boolean_bitor_assign(&mut result, &is_matched), + || self.boolean_bitor_assign(&mut result, &is_matched), ); } @@ -61,8 +60,8 @@ impl ServerKey { str_pat: (CharIter, &str), par_iter: IntoIter, ) -> (RadixCiphertext, BooleanBlock) { - let mut result = self.key.create_trivial_boolean_block(false); - let mut last_match_index = self.key.create_trivial_zero_radix(16); + let mut result = self.create_trivial_boolean_block(false); + let mut last_match_index = self.create_trivial_zero_radix(16); let (str, pat) = str_pat; let matched: Vec<_> = par_iter @@ -76,16 +75,15 @@ impl ServerKey { .collect(); for (i, is_matched) in matched { - let index = self.key.create_trivial_radix(i as u32, 16); + let index = self.create_trivial_radix(i as u32, 16); rayon::join( || { last_match_index = - self.key - .if_then_else_parallelized(&is_matched, &index, &last_match_index) + self.if_then_else_parallelized(&is_matched, &index, &last_match_index) }, // One of the possible values of the padded pat must match the str - || self.key.boolean_bitor_assign(&mut result, &is_matched), + || self.boolean_bitor_assign(&mut result, &is_matched), ); } @@ -105,10 +103,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (haystack, needle) = ("hello world", "world"); /// /// let enc_haystack = FheString::new(&ck, haystack, None); @@ -116,8 +116,8 @@ impl ServerKey { /// /// let (index, found) = sk.find(&enc_haystack, &enc_needle); /// - /// let index = ck.key().decrypt_radix::(&index); - /// let found = ck.key().decrypt_bool(&found); + /// let index = ck.decrypt_radix::(&index); + /// let found = ck.decrypt_bool(&found); /// /// assert!(found); /// assert_eq!(index, 6); // "world" starts at index 6 in "hello world" @@ -128,11 +128,11 @@ impl ServerKey { GenericPattern::Enc(pat) => pat.clone(), }; - let zero = self.key.create_trivial_zero_radix(16); + let zero = self.create_trivial_zero_radix(16); match self.length_checks(str, &trivial_or_enc_pat) { // bool is true if pattern is empty, in which the first match index is 0. If it's false // we default to 0 as well - IsMatch::Clear(bool) => return (zero, self.key.create_trivial_boolean_block(bool)), + IsMatch::Clear(bool) => return (zero, self.create_trivial_boolean_block(bool)), // This variant is only returned in the empty string case so in any case index is 0 IsMatch::Cipher(val) => return (zero, val), @@ -173,10 +173,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (haystack, needle) = ("hello world world", "world"); /// /// let enc_haystack = FheString::new(&ck, haystack, None); @@ -184,8 +186,8 @@ impl ServerKey { /// /// let (index, found) = sk.rfind(&enc_haystack, &enc_needle); /// - /// let index = ck.key().decrypt_radix::(&index); - /// let found = ck.key().decrypt_bool(&found); + /// let index = ck.decrypt_radix::(&index); + /// let found = ck.decrypt_bool(&found); /// /// assert!(found); /// assert_eq!(index, 12); // The last "world" starts at index 12 in "hello world world" @@ -196,22 +198,20 @@ impl ServerKey { GenericPattern::Enc(pat) => pat.clone(), }; - let zero = self.key.create_trivial_zero_radix(16); + let zero = self.create_trivial_zero_radix(16); match self.length_checks(str, &trivial_or_enc_pat) { IsMatch::Clear(val) => { // val = true if pattern is empty, in which the last match index = str.len() let index = if val { match self.len(str) { FheStringLen::Padding(cipher_len) => cipher_len, - FheStringLen::NoPadding(len) => { - self.key.create_trivial_radix(len as u32, 16) - } + FheStringLen::NoPadding(len) => self.create_trivial_radix(len as u32, 16), } } else { zero }; - return (index, self.key.create_trivial_boolean_block(val)); + return (index, self.create_trivial_boolean_block(val)); } // This variant is only returned in the empty string case so in any case index is 0 @@ -256,9 +256,7 @@ impl ServerKey { if str.is_padded() && padded_pat_is_empty.is_some() { let str_true_len = match self.len(str) { FheStringLen::Padding(cipher_len) => cipher_len, - FheStringLen::NoPadding(len) => { - self.key.create_trivial_radix(len as u32, 16) - } + FheStringLen::NoPadding(len) => self.create_trivial_radix(len as u32, 16), }; Some((padded_pat_is_empty.unwrap(), str_true_len)) @@ -270,8 +268,7 @@ impl ServerKey { if let Some((pat_is_empty, str_true_len)) = option { last_match_index = - self.key - .if_then_else_parallelized(&pat_is_empty, &str_true_len, &last_match_index); + self.if_then_else_parallelized(&pat_is_empty, &str_true_len, &last_match_index); } (last_match_index, result) diff --git a/tfhe/src/strings/server_key/pattern/replace.rs b/tfhe/src/strings/server_key/pattern/replace.rs index d98d98990d..a51e7ab9a8 100644 --- a/tfhe/src/strings/server_key/pattern/replace.rs +++ b/tfhe/src/strings/server_key/pattern/replace.rs @@ -1,8 +1,8 @@ +use crate::integer::prelude::*; +use crate::integer::{BooleanBlock, RadixCiphertext}; use crate::strings::ciphertext::{FheString, GenericPattern, UIntArg}; use crate::strings::server_key::pattern::IsMatch; use crate::strings::server_key::{FheStringIsEmpty, FheStringLen, ServerKey}; -use crate::integer::prelude::*; -use crate::integer::{BooleanBlock, RadixCiphertext}; impl ServerKey { // Replaces the pattern ignoring the first `start` chars (i.e. these are not replaced) @@ -30,10 +30,10 @@ impl ServerKey { let (mut replaced, rhs) = rayon::join( || { - let str_len = self.key.create_trivial_radix(str.len() as u32, 16); + let str_len = self.create_trivial_radix(str.len() as u32, 16); // Get the [lhs] shifting right by [from, rhs].len() - let shift_right = self.key.sub_parallelized(&str_len, find_index); + let shift_right = self.sub_parallelized(&str_len, find_index); let mut lhs = self.right_shift_chars(str, &shift_right); // As lhs is shifted right we know there aren't nulls on the right, unless empty lhs.set_is_padded(false); @@ -50,11 +50,9 @@ impl ServerKey { // Get the [rhs] shifting left by [lhs, from].len() let shift_left = match from_len { FheStringLen::NoPadding(len) => { - self.key.scalar_add_parallelized(find_index, *len as u32) - } - FheStringLen::Padding(enc_len) => { - self.key.add_parallelized(find_index, enc_len) + self.scalar_add_parallelized(find_index, *len as u32) } + FheStringLen::Padding(enc_len) => self.add_parallelized(find_index, enc_len), }; let mut rhs = self.left_shift_chars(str, &shift_left); @@ -71,12 +69,12 @@ impl ServerKey { || self.conditional_string(replace, replaced, str), || { // If there's match we return [lhs, to].len(), else we return 0 (index default) - let add_to_index = self.key.if_then_else_parallelized( + let add_to_index = self.if_then_else_parallelized( replace, enc_to_len, - &self.key.create_trivial_zero_radix(16), + &self.create_trivial_zero_radix(16), ); - self.key.add_parallelized(find_index, &add_to_index) + self.add_parallelized(find_index, &add_to_index) }, ) } @@ -89,7 +87,7 @@ impl ServerKey { to: &FheString, enc_n: Option<&RadixCiphertext>, ) { - let mut skip = self.key.create_trivial_zero_radix(16); + let mut skip = self.create_trivial_zero_radix(16); let trivial_or_enc_from = match from { GenericPattern::Clear(from) => FheString::trivial(self, from.str()), GenericPattern::Enc(from) => from.clone(), @@ -107,9 +105,7 @@ impl ServerKey { || self.len(result), || match self.len(to) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => { - self.key.create_trivial_radix(val as u32, 16) - } + FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), }, ) }, @@ -127,7 +123,7 @@ impl ServerKey { let (mut index, is_match) = self.find(&shifted_str, from); // We add `skip` to get the actual index of the pattern (in the non shifted str) - self.key.add_assign_parallelized(&mut index, &skip); + self.add_assign_parallelized(&mut index, &skip); (*result, skip) = self.replace_once(&is_match, &index, &from_len, &enc_to_len, result, to); @@ -142,13 +138,12 @@ impl ServerKey { // If we replace "" to "a" in the "ww" str, we get "awawa". So when `from_is_empty` // we need to move to the next space between letters by adding 1 to the skip value || match &from_is_empty { - FheStringIsEmpty::Padding(enc) => self.key.add_assign_parallelized( + FheStringIsEmpty::Padding(enc) => self.add_assign_parallelized( &mut skip, - &enc.clone().into_radix(num_blocks, &self.key), + &enc.clone().into_radix(num_blocks, self), ), FheStringIsEmpty::NoPadding(clear) => { - self.key - .scalar_add_assign_parallelized(&mut skip, *clear as u8); + self.scalar_add_assign_parallelized(&mut skip, *clear as u8); } }, ); @@ -171,29 +166,26 @@ impl ServerKey { || { let no_more_matches = match &str_len { FheStringLen::Padding(enc) => { - self.key.scalar_lt_parallelized(enc, current_iteration) + self.scalar_lt_parallelized(enc, current_iteration) + } + FheStringLen::NoPadding(clear) => { + self.create_trivial_boolean_block(*clear < current_iteration as usize) } - FheStringLen::NoPadding(clear) => self - .key - .create_trivial_boolean_block(*clear < current_iteration as usize), }; match &from_is_empty { - FheStringIsEmpty::Padding(enc) => { - self.key.boolean_bitand(&no_more_matches, enc) - } + FheStringIsEmpty::Padding(enc) => self.boolean_bitand(&no_more_matches, enc), FheStringIsEmpty::NoPadding(clear) => { - let trivial = self.key.create_trivial_boolean_block(*clear); - self.key.boolean_bitand(&no_more_matches, &trivial) + let trivial = self.create_trivial_boolean_block(*clear); + self.boolean_bitand(&no_more_matches, &trivial) } } }, - || enc_n.map(|n| self.key.scalar_le_parallelized(n, current_iteration)), + || enc_n.map(|n| self.scalar_le_parallelized(n, current_iteration)), ); if let Some(exceeded) = enc_n_is_exceeded { - self.key - .boolean_bitor_assign(&mut no_more_matches, &exceeded); + self.boolean_bitor_assign(&mut no_more_matches, &exceeded); } no_more_matches @@ -216,10 +208,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern, UIntArg}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, from, to) = ("hello", "l", "r"); /// /// let enc_s = FheString::new(&ck, s, None); @@ -272,7 +266,7 @@ impl ServerKey { // We have to take into account that encrypted n could be 0 if let UIntArg::Enc(enc_n) = count { - let n_is_zero = self.key.scalar_eq_parallelized(enc_n.cipher(), 0); + let n_is_zero = self.scalar_eq_parallelized(enc_n.cipher(), 0); let mut re = self.conditional_string(&n_is_zero, result, to); @@ -292,8 +286,8 @@ impl ServerKey { } if let UIntArg::Enc(enc_n) = count { - let n_not_zero = self.key.scalar_ne_parallelized(enc_n.cipher(), 0); - let and_val = self.key.boolean_bitand(&n_not_zero, &val); + let n_not_zero = self.scalar_ne_parallelized(enc_n.cipher(), 0); + let and_val = self.boolean_bitand(&n_not_zero, &val); let mut re = self.conditional_string(&and_val, to.clone(), str); @@ -341,10 +335,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{ClearString, FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, from, to) = ("hi", "i", "o"); /// /// let enc_s = FheString::new(&ck, s, None); diff --git a/tfhe/src/strings/server_key/pattern/split/mod.rs b/tfhe/src/strings/server_key/pattern/split/mod.rs index 76c74fc40c..5c5cbef221 100644 --- a/tfhe/src/strings/server_key/pattern/split/mod.rs +++ b/tfhe/src/strings/server_key/pattern/split/mod.rs @@ -13,17 +13,17 @@ impl ServerKey { index: &RadixCiphertext, inclusive: bool, ) -> (FheString, FheString) { - let str_len = self.key.create_trivial_radix(str.len() as u32, 16); + let str_len = self.create_trivial_radix(str.len() as u32, 16); let trivial_or_enc_pat = match pat { GenericPattern::Clear(pat) => FheString::trivial(self, pat.str()), GenericPattern::Enc(pat) => pat.clone(), }; let (mut shift_right, real_pat_len) = rayon::join( - || self.key.sub_parallelized(&str_len, index), + || self.sub_parallelized(&str_len, index), || match self.len(&trivial_or_enc_pat) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => self.key.create_trivial_radix(val as u32, 16), + FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), }, ); @@ -31,8 +31,7 @@ impl ServerKey { || { if inclusive { // Remove the real pattern length from the amount to shift - self.key - .sub_assign_parallelized(&mut shift_right, &real_pat_len); + self.sub_assign_parallelized(&mut shift_right, &real_pat_len); } let lhs = self.right_shift_chars(str, &shift_right); @@ -42,7 +41,7 @@ impl ServerKey { self.left_shift_chars(&lhs, &shift_right) }, || { - let shift_left = self.key.add_parallelized(&real_pat_len, index); + let shift_left = self.add_parallelized(&real_pat_len, index); self.left_shift_chars(str, &shift_left) }, @@ -74,10 +73,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, pat) = (" hello world", " "); /// let enc_s = FheString::new(&ck, s, None); /// let enc_pat = GenericPattern::Enc(FheString::new(&ck, pat, None)); @@ -86,7 +87,7 @@ impl ServerKey { /// /// let lhs_decrypted = ck.decrypt_ascii(&lhs); /// let rhs_decrypted = ck.decrypt_ascii(&rhs); - /// let split_occurred = ck.key().decrypt_bool(&split_occurred); + /// let split_occurred = ck.decrypt_bool(&split_occurred); /// /// assert_eq!(lhs_decrypted, " hello"); /// assert_eq!(rhs_decrypted, "world"); @@ -109,14 +110,14 @@ impl ServerKey { ( str.clone(), FheString::empty(), - self.key.create_trivial_boolean_block(true), + self.create_trivial_boolean_block(true), ) } else { // There's no match so we default to empty string and str ( FheString::empty(), str.clone(), - self.key.create_trivial_boolean_block(false), + self.create_trivial_boolean_block(false), ) }; } @@ -144,10 +145,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, pat) = (" hello world", " "); /// let enc_s = FheString::new(&ck, s, None); /// let enc_pat = GenericPattern::Enc(FheString::new(&ck, pat, None)); @@ -156,7 +159,7 @@ impl ServerKey { /// /// let lhs_decrypted = ck.decrypt_ascii(&lhs); /// let rhs_decrypted = ck.decrypt_ascii(&rhs); - /// let split_occurred = ck.key().decrypt_bool(&split_occurred); + /// let split_occurred = ck.decrypt_bool(&split_occurred); /// /// assert_eq!(lhs_decrypted, ""); /// assert_eq!(rhs_decrypted, "hello world"); @@ -179,14 +182,14 @@ impl ServerKey { ( FheString::empty(), str.clone(), - self.key.create_trivial_boolean_block(true), + self.create_trivial_boolean_block(true), ) } else { // There's no match so we default to empty string and str ( FheString::empty(), str.clone(), - self.key.create_trivial_boolean_block(false), + self.create_trivial_boolean_block(false), ) }; } @@ -210,19 +213,19 @@ impl ServerKey { ) -> SplitInternal { let mut max_counter = match self.len(str) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => self.key.create_trivial_radix(val as u32, 16), + FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), }; - self.key.scalar_add_assign_parallelized(&mut max_counter, 1); + self.scalar_add_assign_parallelized(&mut max_counter, 1); SplitInternal { split_type, state: str.clone(), pat: pat.clone(), - prev_was_some: self.key.create_trivial_boolean_block(true), + prev_was_some: self.create_trivial_boolean_block(true), counter: 0, max_counter, - counter_lt_max: self.key.create_trivial_boolean_block(true), + counter_lt_max: self.create_trivial_boolean_block(true), } } @@ -240,12 +243,12 @@ impl ServerKey { let uint_not_0 = match &n { UIntArg::Clear(val) => { if *val != 0 { - self.key.create_trivial_boolean_block(true) + self.create_trivial_boolean_block(true) } else { - self.key.create_trivial_boolean_block(false) + self.create_trivial_boolean_block(false) } } - UIntArg::Enc(enc) => self.key.scalar_ne_parallelized(enc.cipher(), 0), + UIntArg::Enc(enc) => self.scalar_ne_parallelized(enc.cipher(), 0), }; let internal = self.split_internal(str, pat, split_type); @@ -270,17 +273,17 @@ impl ServerKey { let max_counter = match self.len(str) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => self.key.create_trivial_radix(val as u32, 16), + FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), }; let internal = SplitInternal { split_type, state: str.clone(), pat: pat.clone(), - prev_was_some: self.key.create_trivial_boolean_block(true), + prev_was_some: self.create_trivial_boolean_block(true), counter: 0, max_counter, - counter_lt_max: self.key.create_trivial_boolean_block(true), + counter_lt_max: self.create_trivial_boolean_block(true), }; SplitNoTrailing { internal } @@ -293,7 +296,7 @@ impl ServerKey { let leading_empty_str = match self.is_empty(&prev_return.0) { FheStringIsEmpty::Padding(enc) => enc, - FheStringIsEmpty::NoPadding(clear) => self.key.create_trivial_boolean_block(clear), + FheStringIsEmpty::NoPadding(clear) => self.create_trivial_boolean_block(clear), }; SplitNoLeading { @@ -353,8 +356,8 @@ impl FheStringIterator for SplitInternal { } }, || match sk.is_empty(&trivial_or_enc_pat) { - FheStringIsEmpty::Padding(enc) => enc.into_radix(16, &sk.key), - FheStringIsEmpty::NoPadding(clear) => sk.key.create_trivial_radix(clear as u32, 16), + FheStringIsEmpty::Padding(enc) => enc.into_radix(16, sk), + FheStringIsEmpty::NoPadding(clear) => sk.create_trivial_radix(clear as u32, 16), }, ); @@ -367,9 +370,9 @@ impl FheStringIterator for SplitInternal { // start (or end in the rsplit case) if matches!(self.split_type, SplitType::RSplit) { - sk.key.sub_assign_parallelized(&mut index, &pat_is_empty); + sk.sub_assign_parallelized(&mut index, &pat_is_empty); } else { - sk.key.add_assign_parallelized(&mut index, &pat_is_empty); + sk.add_assign_parallelized(&mut index, &pat_is_empty); } } @@ -396,18 +399,14 @@ impl FheStringIterator for SplitInternal { // Even if there isn't match, we return Some if there was match in the previous next call, // as we are returning the remaining state "wrapped" in Some - sk.key - .boolean_bitor_assign(&mut is_some, &self.prev_was_some); + sk.boolean_bitor_assign(&mut is_some, &self.prev_was_some); // If pattern is empty, `is_some` is always true, so we make it false when we have reached // the last possible counter value - sk.key - .boolean_bitand_assign(&mut is_some, &self.counter_lt_max); + sk.boolean_bitand_assign(&mut is_some, &self.counter_lt_max); self.prev_was_some = current_is_some; - self.counter_lt_max = sk - .key - .scalar_gt_parallelized(&self.max_counter, self.counter); + self.counter_lt_max = sk.scalar_gt_parallelized(&self.max_counter, self.counter); self.counter += 1; @@ -422,8 +421,7 @@ impl FheStringIterator for SplitNInternal { let (mut result, mut is_some) = self.internal.next(sk); // This keeps the original `is_some` value unless we have exceeded n - sk.key - .boolean_bitand_assign(&mut is_some, &self.not_exceeded); + sk.boolean_bitand_assign(&mut is_some, &self.not_exceeded); // The moment counter is at least one less than n we return the remaining state, and make // `not_exceeded` false such that next calls are always None @@ -431,25 +429,24 @@ impl FheStringIterator for SplitNInternal { UIntArg::Clear(clear_n) => { if self.counter >= clear_n - 1 { result = state; - self.not_exceeded = sk.key.create_trivial_boolean_block(false); + self.not_exceeded = sk.create_trivial_boolean_block(false); } } UIntArg::Enc(enc_n) => { // Note that when `enc_n` is zero `n_minus_one` wraps to a very large number and so // `exceeded` will be false. Nonetheless the initial value of `not_exceeded` // was set to false in the n is zero case, so we return None - let n_minus_one = sk.key.scalar_sub_parallelized(enc_n.cipher(), 1); - let exceeded = sk.key.scalar_le_parallelized(&n_minus_one, self.counter); + let n_minus_one = sk.scalar_sub_parallelized(enc_n.cipher(), 1); + let exceeded = sk.scalar_le_parallelized(&n_minus_one, self.counter); rayon::join( || result = sk.conditional_string(&exceeded, state, &result), || { - let current_not_exceeded = sk.key.boolean_bitnot(&exceeded); + let current_not_exceeded = sk.boolean_bitnot(&exceeded); // If current is not exceeded we use the previous not_exceeded value, // or false if it's exceeded - sk.key - .boolean_bitand_assign(&mut self.not_exceeded, ¤t_not_exceeded); + sk.boolean_bitand_assign(&mut self.not_exceeded, ¤t_not_exceeded); }, ); } @@ -471,19 +468,18 @@ impl FheStringIterator for SplitNoTrailing { // string, we return None to remove it || match sk.is_empty(&result) { FheStringIsEmpty::Padding(enc) => enc, - FheStringIsEmpty::NoPadding(clear) => sk.key.create_trivial_boolean_block(clear), + FheStringIsEmpty::NoPadding(clear) => sk.create_trivial_boolean_block(clear), }, - || sk.key.boolean_bitnot(&self.internal.prev_was_some), + || sk.boolean_bitnot(&self.internal.prev_was_some), ); - let trailing_empty_str = sk.key.boolean_bitand(&result_is_empty, &prev_was_none); + let trailing_empty_str = sk.boolean_bitand(&result_is_empty, &prev_was_none); - let not_trailing_empty_str = sk.key.boolean_bitnot(&trailing_empty_str); + let not_trailing_empty_str = sk.boolean_bitnot(&trailing_empty_str); // If there's no empty trailing string we get the previous `is_some`, // else we get false (None) - sk.key - .boolean_bitand_assign(&mut is_some, ¬_trailing_empty_str); + sk.boolean_bitand_assign(&mut is_some, ¬_trailing_empty_str); (result, is_some) } @@ -504,18 +500,18 @@ impl FheStringIterator for SplitNoLeading { || { let (lhs, rhs) = rayon::join( // This is `is_some` if `leading_empty_str` is true, false otherwise - || sk.key.boolean_bitand(&self.leading_empty_str, &is_some), + || sk.boolean_bitand(&self.leading_empty_str, &is_some), // This is the flag from the previous next call if `leading_empty_str` is true, // false otherwise || { - sk.key.boolean_bitand( - &sk.key.boolean_bitnot(&self.leading_empty_str), + sk.boolean_bitand( + &sk.boolean_bitnot(&self.leading_empty_str), &self.prev_return.1, ) }, ); - sk.key.boolean_bitor(&lhs, &rhs) + sk.boolean_bitor(&lhs, &rhs) }, ); diff --git a/tfhe/src/strings/server_key/pattern/split/split_iters.rs b/tfhe/src/strings/server_key/pattern/split/split_iters.rs index d82f19a718..1cd4d98506 100644 --- a/tfhe/src/strings/server_key/pattern/split/split_iters.rs +++ b/tfhe/src/strings/server_key/pattern/split/split_iters.rs @@ -1,9 +1,9 @@ +use crate::integer::BooleanBlock; use crate::strings::ciphertext::{FheString, GenericPattern, UIntArg}; use crate::strings::server_key::pattern::split::{ SplitInternal, SplitNInternal, SplitNoLeading, SplitNoTrailing, SplitType, }; use crate::strings::server_key::{FheStringIterator, ServerKey}; -use crate::integer::BooleanBlock; pub struct RSplit { internal: SplitInternal, @@ -47,10 +47,13 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::{gen_keys, FheStringIterator}; + /// use tfhe::strings::server_key::FheStringIterator; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, pat) = ("hello ", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -62,10 +65,10 @@ impl ServerKey { /// let (_, no_more_items) = split_iter.next(&sk); // Attempting to get a third item /// /// let first_decrypted = ck.decrypt_ascii(&first_item); - /// let first_is_some = ck.key().decrypt_bool(&first_is_some); + /// let first_is_some = ck.decrypt_bool(&first_is_some); /// let second_decrypted = ck.decrypt_ascii(&second_item); - /// let second_is_some = ck.key().decrypt_bool(&second_is_some); - /// let no_more_items = ck.key().decrypt_bool(&no_more_items); + /// let second_is_some = ck.decrypt_bool(&second_is_some); + /// let no_more_items = ck.decrypt_bool(&no_more_items); /// /// assert_eq!(first_decrypted, "hello"); /// assert!(first_is_some); // There is a first item @@ -92,10 +95,13 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::{gen_keys, FheStringIterator}; + /// use tfhe::strings::server_key::FheStringIterator; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, pat) = ("hello ", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -107,10 +113,10 @@ impl ServerKey { /// let (_, no_more_items) = rsplit_iter.next(&sk); // Attempting to get a third item /// /// let last_decrypted = ck.decrypt_ascii(&last_item); - /// let last_is_some = ck.key().decrypt_bool(&last_is_some); + /// let last_is_some = ck.decrypt_bool(&last_is_some); /// let second_last_decrypted = ck.decrypt_ascii(&second_last_item); - /// let second_last_is_some = ck.key().decrypt_bool(&second_last_is_some); - /// let no_more_items = ck.key().decrypt_bool(&no_more_items); + /// let second_last_is_some = ck.decrypt_bool(&second_last_is_some); + /// let no_more_items = ck.decrypt_bool(&no_more_items); /// /// assert_eq!(last_decrypted, ""); /// assert!(last_is_some); // The last item is empty @@ -138,10 +144,13 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern, UIntArg}; - /// use tfhe::strings::server_key::{gen_keys, FheStringIterator}; + /// use tfhe::strings::server_key::FheStringIterator; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, pat) = ("hello world", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -154,8 +163,8 @@ impl ServerKey { /// let (_, no_more_items) = splitn_iter.next(&sk); // Attempting to get a second item /// /// let first_decrypted = ck.decrypt_ascii(&first_item); - /// let first_is_some = ck.key().decrypt_bool(&first_is_some); - /// let no_more_items = ck.key().decrypt_bool(&no_more_items); + /// let first_is_some = ck.decrypt_bool(&first_is_some); + /// let no_more_items = ck.decrypt_bool(&no_more_items); /// /// // We get the whole str as n is 1 /// assert_eq!(first_decrypted, "hello world"); @@ -190,10 +199,13 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern, UIntArg}; - /// use tfhe::strings::server_key::{gen_keys, FheStringIterator}; + /// use tfhe::strings::server_key::FheStringIterator; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, pat) = ("hello world", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -206,8 +218,8 @@ impl ServerKey { /// let (_, no_more_items) = rsplitn_iter.next(&sk); // Attempting to get a second item /// /// let last_decrypted = ck.decrypt_ascii(&last_item); - /// let last_is_some = ck.key().decrypt_bool(&last_is_some); - /// let no_more_items = ck.key().decrypt_bool(&no_more_items); + /// let last_is_some = ck.decrypt_bool(&last_is_some); + /// let no_more_items = ck.decrypt_bool(&no_more_items); /// /// // We get the whole str as n is 1 /// assert_eq!(last_decrypted, "hello world"); @@ -240,10 +252,13 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::{gen_keys, FheStringIterator}; + /// use tfhe::strings::server_key::FheStringIterator; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, pat) = ("hello world ", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -255,10 +270,10 @@ impl ServerKey { /// let (_, no_more_items) = split_terminator_iter.next(&sk); // Attempting to get a third item /// /// let first_decrypted = ck.decrypt_ascii(&first_item); - /// let first_is_some = ck.key().decrypt_bool(&first_is_some); + /// let first_is_some = ck.decrypt_bool(&first_is_some); /// let second_decrypted = ck.decrypt_ascii(&second_item); - /// let second_is_some = ck.key().decrypt_bool(&second_is_some); - /// let no_more_items = ck.key().decrypt_bool(&no_more_items); + /// let second_is_some = ck.decrypt_bool(&second_is_some); + /// let no_more_items = ck.decrypt_bool(&no_more_items); /// /// assert_eq!(first_decrypted, "hello"); /// assert!(first_is_some); // There is a first item @@ -288,10 +303,13 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::{gen_keys, FheStringIterator}; + /// use tfhe::strings::server_key::FheStringIterator; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, pat) = ("hello world ", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -303,10 +321,10 @@ impl ServerKey { /// let (_, no_more_items) = rsplit_terminator_iter.next(&sk); // Attempting to get a third item /// /// let last_decrypted = ck.decrypt_ascii(&last_item); - /// let last_is_some = ck.key().decrypt_bool(&last_is_some); + /// let last_is_some = ck.decrypt_bool(&last_is_some); /// let second_last_decrypted = ck.decrypt_ascii(&second_last_item); - /// let second_last_is_some = ck.key().decrypt_bool(&second_last_is_some); - /// let no_more_items = ck.key().decrypt_bool(&no_more_items); + /// let second_last_is_some = ck.decrypt_bool(&second_last_is_some); + /// let no_more_items = ck.decrypt_bool(&no_more_items); /// /// assert_eq!(last_decrypted, "world"); /// assert!(last_is_some); // The last item is "world" instead of "" @@ -335,10 +353,13 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{FheString, GenericPattern}; - /// use tfhe::strings::server_key::{gen_keys, FheStringIterator}; + /// use tfhe::strings::server_key::FheStringIterator; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, pat) = ("hello world ", " "); /// /// let enc_s = FheString::new(&ck, s, None); @@ -350,10 +371,10 @@ impl ServerKey { /// let (_, no_more_items) = split_inclusive_iter.next(&sk); // Attempting to get a third item /// /// let first_decrypted = ck.decrypt_ascii(&first_item); - /// let first_is_some = ck.key().decrypt_bool(&first_is_some); + /// let first_is_some = ck.decrypt_bool(&first_is_some); /// let second_decrypted = ck.decrypt_ascii(&second_item); - /// let second_is_some = ck.key().decrypt_bool(&second_is_some); - /// let no_more_items = ck.key().decrypt_bool(&no_more_items); + /// let second_is_some = ck.decrypt_bool(&second_is_some); + /// let no_more_items = ck.decrypt_bool(&no_more_items); /// /// assert_eq!(first_decrypted, "hello "); /// assert!(first_is_some); // The first item includes the delimiter diff --git a/tfhe/src/strings/server_key/pattern/strip.rs b/tfhe/src/strings/server_key/pattern/strip.rs index cb5c467e85..5fead6e3f8 100644 --- a/tfhe/src/strings/server_key/pattern/strip.rs +++ b/tfhe/src/strings/server_key/pattern/strip.rs @@ -14,7 +14,7 @@ impl ServerKey { str_pat: (CharIter, CharIter), iter: Range, ) -> BooleanBlock { - let mut result = self.key.create_trivial_boolean_block(false); + let mut result = self.create_trivial_boolean_block(false); let (str, pat) = str_pat; let pat_len = pat.len(); @@ -23,10 +23,10 @@ impl ServerKey { for start in iter { let is_matched = self.asciis_eq(str.iter().copied().skip(start), pat.iter().copied()); - let mut mask = is_matched.clone().into_radix(4, &self.key); + let mut mask = is_matched.clone().into_radix(4, self); // If mask == 0u8, it will now be 255u8. If it was 1u8, it will now be 0u8 - self.key.scalar_sub_assign_parallelized(&mut mask, 1); + self.scalar_sub_assign_parallelized(&mut mask, 1); let mutate_chars = strip_str.chars_mut().par_iter_mut().skip(start).take( if start + pat_len < str_len { @@ -39,12 +39,11 @@ impl ServerKey { rayon::join( || { mutate_chars.for_each(|char| { - self.key - .bitand_assign_parallelized(char.ciphertext_mut(), &mask); + self.bitand_assign_parallelized(char.ciphertext_mut(), &mask); }); }, // One of the possible values of pat must match the str - || self.key.boolean_bitor_assign(&mut result, &is_matched), + || self.boolean_bitor_assign(&mut result, &is_matched), ); } @@ -57,7 +56,7 @@ impl ServerKey { str_pat: (CharIter, &str), iter: Range, ) -> BooleanBlock { - let mut result = self.key.create_trivial_boolean_block(false); + let mut result = self.create_trivial_boolean_block(false); let (str, pat) = str_pat; let pat_len = pat.len(); @@ -65,10 +64,10 @@ impl ServerKey { for start in iter { let is_matched = self.clear_asciis_eq(str.iter().copied().skip(start), pat); - let mut mask = is_matched.clone().into_radix(4, &self.key); + let mut mask = is_matched.clone().into_radix(4, self); // If mask == 0u8, it will now be 255u8. If it was 1u8, it will now be 0u8 - self.key.scalar_sub_assign_parallelized(&mut mask, 1); + self.scalar_sub_assign_parallelized(&mut mask, 1); let mutate_chars = strip_str.chars_mut().par_iter_mut().skip(start).take( if start + pat_len < str_len { @@ -81,12 +80,11 @@ impl ServerKey { rayon::join( || { mutate_chars.for_each(|char| { - self.key - .bitand_assign_parallelized(char.ciphertext_mut(), &mask); + self.bitand_assign_parallelized(char.ciphertext_mut(), &mask); }); }, // One of the possible values of pat must match the str - || self.key.boolean_bitor_assign(&mut result, &is_matched), + || self.boolean_bitor_assign(&mut result, &is_matched), ); } @@ -106,10 +104,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{ClearString, FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, prefix, not_prefix) = ("hello world", "hello", "world"); /// /// let enc_s = FheString::new(&ck, s, None); @@ -118,11 +118,11 @@ impl ServerKey { /// /// let (result, found) = sk.strip_prefix(&enc_s, &enc_prefix); /// let stripped = ck.decrypt_ascii(&result); - /// let found = ck.key().decrypt_bool(&found); + /// let found = ck.decrypt_bool(&found); /// /// let (result_no_match, not_found) = sk.strip_prefix(&enc_s, &clear_not_prefix); /// let not_stripped = ck.decrypt_ascii(&result_no_match); - /// let not_found = ck.key().decrypt_bool(¬_found); + /// let not_found = ck.decrypt_bool(¬_found); /// /// assert!(found); /// assert_eq!(stripped, " world"); // "hello" is stripped from "hello world" @@ -139,7 +139,7 @@ impl ServerKey { match self.length_checks(str, &trivial_or_enc_pat) { // If IsMatch is Clear we return the same string (a true means the pattern is empty) - IsMatch::Clear(bool) => return (result, self.key.create_trivial_boolean_block(bool)), + IsMatch::Clear(bool) => return (result, self.create_trivial_boolean_block(bool)), // If IsMatch is Cipher it means str is empty so in any case we return the same string IsMatch::Cipher(val) => return (result, val), @@ -150,16 +150,16 @@ impl ServerKey { || self.starts_with(str, pat), || match self.len(&trivial_or_enc_pat) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => self.key.create_trivial_radix(val as u32, 16), + FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), }, ); // If there's match we shift the str left by `real_pat_len` (removing the prefix and adding // nulls at the end), else we shift it left by 0 - let shift_left = self.key.if_then_else_parallelized( + let shift_left = self.if_then_else_parallelized( &starts_with, &real_pat_len, - &self.key.create_trivial_zero_radix(16), + &self.create_trivial_zero_radix(16), ); result = self.left_shift_chars(str, &shift_left); @@ -189,10 +189,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::{ClearString, FheString, GenericPattern}; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let (s, suffix, not_suffix) = ("hello world", "world", "hello"); /// /// let enc_s = FheString::new(&ck, s, None); @@ -201,11 +203,11 @@ impl ServerKey { /// /// let (result, found) = sk.strip_suffix(&enc_s, &enc_suffix); /// let stripped = ck.decrypt_ascii(&result); - /// let found = ck.key().decrypt_bool(&found); + /// let found = ck.decrypt_bool(&found); /// /// let (result_no_match, not_found) = sk.strip_suffix(&enc_s, &clear_not_suffix); /// let not_stripped = ck.decrypt_ascii(&result_no_match); - /// let not_found = ck.key().decrypt_bool(¬_found); + /// let not_found = ck.decrypt_bool(¬_found); /// /// assert!(found); /// assert_eq!(stripped, "hello "); // "world" is stripped from "hello world" @@ -223,7 +225,7 @@ impl ServerKey { match self.length_checks(str, &trivial_or_enc_pat) { // If IsMatch is Clear we return the same string (a true means the pattern is empty) - IsMatch::Clear(bool) => return (result, self.key.create_trivial_boolean_block(bool)), + IsMatch::Clear(bool) => return (result, self.create_trivial_boolean_block(bool)), // If IsMatch is Cipher it means str is empty so in any case we return the same string IsMatch::Cipher(val) => return (result, val), diff --git a/tfhe/src/strings/server_key/trim.rs b/tfhe/src/strings/server_key/trim.rs index 34a5348b6f..6fd7c83234 100644 --- a/tfhe/src/strings/server_key/trim.rs +++ b/tfhe/src/strings/server_key/trim.rs @@ -14,10 +14,7 @@ impl FheStringIterator for SplitAsciiWhitespace { let str_len = self.state.len(); if str_len == 0 || (self.state.is_padded() && str_len == 1) { - return ( - FheString::empty(), - sk.key.create_trivial_boolean_block(false), - ); + return (FheString::empty(), sk.create_trivial_boolean_block(false)); } // If we aren't in the first next call `current_mask` is some @@ -34,7 +31,7 @@ impl FheStringIterator for SplitAsciiWhitespace { // If state after trim_start is empty it means the remaining string was either // empty or only whitespace. Hence, there are no more elements to return if let FheStringIsEmpty::Padding(val) = sk.is_empty(&state_after_trim) { - sk.key.boolean_bitnot(&val) + sk.boolean_bitnot(&val) } else { panic!("Empty str case was handled so 'state_after_trim' is padded") } @@ -49,16 +46,16 @@ impl SplitAsciiWhitespace { let mut mask = self.state.clone(); let mut result = self.state.clone(); - let mut prev_was_not = sk.key.create_trivial_boolean_block(true); + let mut prev_was_not = sk.create_trivial_boolean_block(true); for char in mask.chars_mut().iter_mut() { let mut is_not_ws = sk.is_not_whitespace(char); - sk.key.boolean_bitand_assign(&mut is_not_ws, &prev_was_not); + sk.boolean_bitand_assign(&mut is_not_ws, &prev_was_not); - let mut mask_u8 = is_not_ws.clone().into_radix(4, &sk.key); + let mut mask_u8 = is_not_ws.clone().into_radix(4, sk); // 0u8 is kept the same, but 1u8 is transformed into 255u8 - sk.key.scalar_sub_assign_parallelized(&mut mask_u8, 1); - sk.key.bitnot_assign(&mut mask_u8); + sk.scalar_sub_assign_parallelized(&mut mask_u8, 1); + sk.bitnot_assign(&mut mask_u8); *char.ciphertext_mut() = mask_u8; @@ -71,8 +68,7 @@ impl SplitAsciiWhitespace { .par_iter_mut() .zip(mask.chars().par_iter()) .for_each(|(char, mask_u8)| { - sk.key - .bitand_assign_parallelized(char.ciphertext_mut(), mask_u8.ciphertext()); + sk.bitand_assign_parallelized(char.ciphertext_mut(), mask_u8.ciphertext()); }); self.current_mask = Some(mask); @@ -84,16 +80,13 @@ impl SplitAsciiWhitespace { fn remaining_string(&mut self, sk: &ServerKey) { let mask = self.current_mask.as_ref().unwrap(); - let mut number_of_trues: RadixCiphertext = sk.key.create_trivial_zero_radix(16); + let mut number_of_trues: RadixCiphertext = sk.create_trivial_zero_radix(16); for mask_u8 in mask.chars() { - let is_true = sk.key.scalar_eq_parallelized(mask_u8.ciphertext(), 255u8); + let is_true = sk.scalar_eq_parallelized(mask_u8.ciphertext(), 255u8); let num_blocks = number_of_trues.blocks().len(); - sk.key.add_assign_parallelized( - &mut number_of_trues, - &is_true.into_radix(num_blocks, &sk.key), - ); + sk.add_assign_parallelized(&mut number_of_trues, &is_true.into_radix(num_blocks, sk)); } let padded = self.state.is_padded(); @@ -119,39 +112,33 @@ impl ServerKey { rayon::join( || { rayon::join( - || self.key.scalar_eq_parallelized(char.ciphertext(), 0x20u8), - || self.key.scalar_eq_parallelized(char.ciphertext(), 0x09u8), + || self.scalar_eq_parallelized(char.ciphertext(), 0x20u8), + || self.scalar_eq_parallelized(char.ciphertext(), 0x09u8), ) }, || { rayon::join( - || self.key.scalar_eq_parallelized(char.ciphertext(), 0x0Au8), - || self.key.scalar_eq_parallelized(char.ciphertext(), 0x0Cu8), + || self.scalar_eq_parallelized(char.ciphertext(), 0x0Au8), + || self.scalar_eq_parallelized(char.ciphertext(), 0x0Cu8), ) }, ) }, || { rayon::join( - || self.key.scalar_eq_parallelized(char.ciphertext(), 0x0Du8), - || { - or_null - .then_some(self.key.scalar_eq_parallelized(char.ciphertext(), 0u8)) - }, + || self.scalar_eq_parallelized(char.ciphertext(), 0x0Du8), + || or_null.then_some(self.scalar_eq_parallelized(char.ciphertext(), 0u8)), ) }, ); - let mut is_whitespace = self.key.boolean_bitor(&is_space, &is_tab); - self.key - .boolean_bitor_assign(&mut is_whitespace, &is_new_line); - self.key - .boolean_bitor_assign(&mut is_whitespace, &is_form_feed); - self.key - .boolean_bitor_assign(&mut is_whitespace, &is_carriage_return); + let mut is_whitespace = self.boolean_bitor(&is_space, &is_tab); + self.boolean_bitor_assign(&mut is_whitespace, &is_new_line); + self.boolean_bitor_assign(&mut is_whitespace, &is_form_feed); + self.boolean_bitor_assign(&mut is_whitespace, &is_carriage_return); if let Some(is_null) = op_is_null { - self.key.boolean_bitor_assign(&mut is_whitespace, &is_null); + self.boolean_bitor_assign(&mut is_whitespace, &is_null); } is_whitespace @@ -160,22 +147,21 @@ impl ServerKey { fn is_not_whitespace(&self, char: &FheAsciiChar) -> BooleanBlock { let result = self.is_whitespace(char, false); - self.key.boolean_bitnot(&result) + self.boolean_bitnot(&result) } fn compare_and_trim<'a, I>(&self, strip_str: I, starts_with_null: bool) where I: Iterator, { - let mut prev_was_ws = self.key.create_trivial_boolean_block(true); + let mut prev_was_ws = self.create_trivial_boolean_block(true); for char in strip_str { let mut is_whitespace = self.is_whitespace(char, starts_with_null); - self.key - .boolean_bitand_assign(&mut is_whitespace, &prev_was_ws); + self.boolean_bitand_assign(&mut is_whitespace, &prev_was_ws); - *char.ciphertext_mut() = self.key.if_then_else_parallelized( + *char.ciphertext_mut() = self.if_then_else_parallelized( &is_whitespace, - &self.key.create_trivial_zero_radix(4), + &self.create_trivial_zero_radix(4), char.ciphertext(), ); @@ -189,10 +175,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let s = " hello world"; /// /// let enc_s = FheString::new(&ck, s, None); @@ -218,12 +206,10 @@ impl ServerKey { if let FheStringLen::Padding(len_after_trim) = self.len(&result) { let original_str_len = match self.len(str) { FheStringLen::Padding(enc_val) => enc_val, - FheStringLen::NoPadding(val) => self.key.create_trivial_radix(val as u32, 16), + FheStringLen::NoPadding(val) => self.create_trivial_radix(val as u32, 16), }; - let shift_left = self - .key - .sub_parallelized(&original_str_len, &len_after_trim); + let shift_left = self.sub_parallelized(&original_str_len, &len_after_trim); result = self.left_shift_chars(&result, &shift_left); } @@ -245,10 +231,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let s = "hello world "; /// /// let enc_s = FheString::new(&ck, s, None); @@ -285,10 +273,12 @@ impl ServerKey { /// # Examples /// /// ```rust + /// use tfhe::integer::{ClientKey, ServerKey}; + /// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; - /// use tfhe::strings::server_key::gen_keys; /// - /// let (ck, sk) = gen_keys(); + /// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// let sk = ServerKey::new_radix_server_key(&ck); /// let s = " hello world "; /// /// let enc_s = FheString::new(&ck, s, None); @@ -320,10 +310,13 @@ impl ServerKey { /// # Examples /// /// ```rust +/// use tfhe::integer::{ClientKey, ServerKey}; +/// use tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_2; /// use tfhe::strings::ciphertext::FheString; -/// use tfhe::strings::server_key::{gen_keys, split_ascii_whitespace, FheStringIterator}; +/// use tfhe::strings::server_key::{split_ascii_whitespace, FheStringIterator}; /// -/// let (ck, sk) = gen_keys(); +/// let ck = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); +/// let sk = ServerKey::new_radix_server_key(&ck); /// let s = "hello \t\nworld "; /// /// let enc_s = FheString::new(&ck, s, None); @@ -334,11 +327,11 @@ impl ServerKey { /// let (empty, no_more_items) = whitespace_iter.next(&sk); // Attempting to get a third item /// /// let first_decrypted = ck.decrypt_ascii(&first_item); -/// let first_is_some = ck.key().decrypt_bool(&first_is_some); +/// let first_is_some = ck.decrypt_bool(&first_is_some); /// let second_decrypted = ck.decrypt_ascii(&second_item); -/// let second_is_some = ck.key().decrypt_bool(&second_is_some); +/// let second_is_some = ck.decrypt_bool(&second_is_some); /// let empty = ck.decrypt_ascii(&empty); -/// let no_more_items = ck.key().decrypt_bool(&no_more_items); +/// let no_more_items = ck.decrypt_bool(&no_more_items); /// /// assert_eq!(first_decrypted, "hello"); /// assert!(first_is_some);