From f24fa62331edbedb1425560dbdb45f5f7b807d79 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Tue, 3 Dec 2024 18:27:21 +0100 Subject: [PATCH] refactor!: use strong types for outputs of DispersionParameters trait fns --- .../lwe_encryption_noise.rs | 11 +- .../noise_distribution/lwe_keyswitch_noise.rs | 2 +- tfhe/src/core_crypto/commons/dispersion.rs | 180 ++++++++++++------ .../commons/math/random/gaussian.rs | 2 +- .../core_crypto/commons/math/random/mod.rs | 4 +- tfhe/src/core_crypto/commons/mod.rs | 4 +- 6 files changed, 130 insertions(+), 73 deletions(-) diff --git a/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_encryption_noise.rs b/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_encryption_noise.rs index 21b6e89638..2075492a42 100644 --- a/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_encryption_noise.rs +++ b/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_encryption_noise.rs @@ -19,7 +19,7 @@ fn lwe_encrypt_decrypt_noise_distribution_custom_mod Variance { let input_variance = input_noise.get_variance(); - Variance(input_variance * (lwe_dimension.to_lwe_size().0 as f64)) + Variance(input_variance.0 * (lwe_dimension.to_lwe_size().0 as f64)) } #[test] @@ -104,7 +104,8 @@ fn test_variance_increase_cpk_formula() { ); assert!( - (predicted_variance.get_standard_dev().log2() - 44.000704097196405f64).abs() < f64::EPSILON + (predicted_variance.get_standard_dev().0.log2() - 44.000704097196405f64).abs() + < f64::EPSILON ); } @@ -119,7 +120,7 @@ fn lwe_compact_public_encrypt_noise_distribution_custom_mod< let message_modulus_log = params.message_modulus_log; let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus); - let glwe_variance = Variance(glwe_noise_distribution.gaussian_std_dev().get_variance()); + let glwe_variance = glwe_noise_distribution.gaussian_std_dev().get_variance(); let expected_variance = lwe_compact_public_key_encryption_expected_variance(glwe_variance, lwe_dimension); @@ -208,7 +209,7 @@ fn random_noise_roundtrip>( assert!(matches!(noise, DynamicDistribution::Gaussian(_))); - let expected_variance = Variance(noise.gaussian_std_dev().get_variance()); + let expected_variance = noise.gaussian_std_dev().get_variance(); let num_outputs = 100_000; diff --git a/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_keyswitch_noise.rs b/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_keyswitch_noise.rs index 0be5c3b676..80e0d501bd 100644 --- a/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_keyswitch_noise.rs +++ b/tfhe/src/core_crypto/algorithms/test/noise_distribution/lwe_keyswitch_noise.rs @@ -33,7 +33,7 @@ fn lwe_encrypt_ks_decrypt_noise_distribution_custom_mod f64; + fn get_standard_dev(&self) -> StandardDev; /// Return the variance of the distribution, i.e. $\sigma^2 = 2^{2p}$. - fn get_variance(&self) -> f64; + fn get_variance(&self) -> Variance; /// Return base 2 logarithm of the standard deviation of the distribution, i.e. /// $\log\_2(\sigma)=p$ - fn get_log_standard_dev(&self) -> f64; + fn get_log_standard_dev(&self) -> LogStandardDev; /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{q-p}$. - fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64; + fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev; /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{2(q-p)}$. - fn get_modular_variance(&self, log2_modulus: u32) -> f64; + fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance; /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $q-p$. - fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64; + fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev; +} + +fn log2_modulus_to_modulus(log2_modulus: u32) -> f64 { + 2.0f64.powi(log2_modulus as i32) } /// A distribution parameter that uses the base-2 logarithm of the standard deviation as @@ -49,22 +53,31 @@ pub trait DispersionParameter: Copy { /// ```rust /// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, LogStandardDev}; /// let params = LogStandardDev::from_log_standard_dev(-25.); -/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.)); -/// assert_eq!(params.get_log_standard_dev(), -25.); -/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2)); -/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.)); -/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.); +/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.)); +/// assert_eq!(params.get_log_standard_dev().0, -25.); +/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2)); /// assert_eq!( -/// params.get_modular_variance(32), +/// params.get_modular_standard_dev(32).value, +/// 2_f64.powf(32. - 25.) +/// ); +/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.); +/// assert_eq!( +/// params.get_modular_variance(32).value, /// 2_f64.powf(32. - 25.).powi(2) /// ); /// /// let modular_params = LogStandardDev::from_modular_log_standard_dev(22., 32); -/// assert_eq!(modular_params.get_standard_dev(), 2_f64.powf(-10.)); +/// assert_eq!(modular_params.get_standard_dev().0, 2_f64.powf(-10.)); /// ``` #[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] pub struct LogStandardDev(pub f64); +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] +pub struct ModularLogStandardDev { + pub value: f64, + pub modulus: f64, +} + impl LogStandardDev { pub fn from_log_standard_dev(log_std: f64) -> Self { Self(log_std) @@ -76,23 +89,32 @@ impl LogStandardDev { } impl DispersionParameter for LogStandardDev { - fn get_standard_dev(&self) -> f64 { - f64::powf(2., self.0) + fn get_standard_dev(&self) -> StandardDev { + StandardDev(f64::powf(2., self.0)) } - fn get_variance(&self) -> f64 { - f64::powf(2., self.0 * 2.) + fn get_variance(&self) -> Variance { + Variance(f64::powf(2., self.0 * 2.)) } - fn get_log_standard_dev(&self) -> f64 { - self.0 + fn get_log_standard_dev(&self) -> Self { + Self(self.0) } - fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 { - f64::powf(2., log2_modulus as f64 + self.0) + fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev { + ModularStandardDev { + value: f64::powf(2., log2_modulus as f64 + self.0), + modulus: log2_modulus_to_modulus(log2_modulus), + } } - fn get_modular_variance(&self, log2_modulus: u32) -> f64 { - f64::powf(2., (log2_modulus as f64 + self.0) * 2.) + fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance { + ModularVariance { + value: f64::powf(2., (log2_modulus as f64 + self.0) * 2.), + modulus: log2_modulus_to_modulus(log2_modulus), + } } - fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 { - log2_modulus as f64 + self.0 + fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev { + ModularLogStandardDev { + value: log2_modulus as f64 + self.0, + modulus: log2_modulus_to_modulus(log2_modulus), + } } } @@ -103,13 +125,16 @@ impl DispersionParameter for LogStandardDev { /// ```rust /// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, StandardDev}; /// let params = StandardDev::from_standard_dev(2_f64.powf(-25.)); -/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.)); -/// assert_eq!(params.get_log_standard_dev(), -25.); -/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2)); -/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.)); -/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.); +/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.)); +/// assert_eq!(params.get_log_standard_dev().0, -25.); +/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2)); +/// assert_eq!( +/// params.get_modular_standard_dev(32).value, +/// 2_f64.powf(32. - 25.) +/// ); +/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.); /// assert_eq!( -/// params.get_modular_variance(32), +/// params.get_modular_variance(32).value, /// 2_f64.powf(32. - 25.).powi(2) /// ); /// ``` @@ -117,6 +142,12 @@ impl DispersionParameter for LogStandardDev { #[versionize(StandardDevVersions)] pub struct StandardDev(pub f64); +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] +pub struct ModularStandardDev { + pub value: f64, + pub modulus: f64, +} + impl StandardDev { pub fn from_standard_dev(std: f64) -> Self { Self(std) @@ -128,23 +159,32 @@ impl StandardDev { } impl DispersionParameter for StandardDev { - fn get_standard_dev(&self) -> f64 { - self.0 + fn get_standard_dev(&self) -> Self { + Self(self.0) } - fn get_variance(&self) -> f64 { - self.0.powi(2) + fn get_variance(&self) -> Variance { + Variance(self.0.powi(2)) } - fn get_log_standard_dev(&self) -> f64 { - self.0.log2() + fn get_log_standard_dev(&self) -> LogStandardDev { + LogStandardDev(self.0.log2()) } - fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 { - 2_f64.powf(log2_modulus as f64 + self.0.log2()) + fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev { + ModularStandardDev { + value: 2_f64.powf(log2_modulus as f64 + self.0.log2()), + modulus: log2_modulus_to_modulus(log2_modulus), + } } - fn get_modular_variance(&self, log2_modulus: u32) -> f64 { - 2_f64.powf(2. * (log2_modulus as f64 + self.0.log2())) + fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance { + ModularVariance { + value: 2_f64.powf(2. * (log2_modulus as f64 + self.0.log2())), + modulus: log2_modulus_to_modulus(log2_modulus), + } } - fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 { - log2_modulus as f64 + self.0.log2() + fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev { + ModularLogStandardDev { + value: log2_modulus as f64 + self.0.log2(), + modulus: log2_modulus_to_modulus(log2_modulus), + } } } @@ -155,19 +195,28 @@ impl DispersionParameter for StandardDev { /// ```rust /// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, Variance}; /// let params = Variance::from_variance(2_f64.powi(-50)); -/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.)); -/// assert_eq!(params.get_log_standard_dev(), -25.); -/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2)); -/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.)); -/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.); +/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.)); +/// assert_eq!(params.get_log_standard_dev().0, -25.); +/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2)); +/// assert_eq!( +/// params.get_modular_standard_dev(32).value, +/// 2_f64.powf(32. - 25.) +/// ); +/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.); /// assert_eq!( -/// params.get_modular_variance(32), +/// params.get_modular_variance(32).value, /// 2_f64.powf(32. - 25.).powi(2) /// ); /// ``` #[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] pub struct Variance(pub f64); +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] +pub struct ModularVariance { + pub value: f64, + pub modulus: f64, +} + impl Variance { pub fn from_variance(var: f64) -> Self { Self(var) @@ -179,22 +228,31 @@ impl Variance { } impl DispersionParameter for Variance { - fn get_standard_dev(&self) -> f64 { - self.0.sqrt() + fn get_standard_dev(&self) -> StandardDev { + StandardDev(self.0.sqrt()) } - fn get_variance(&self) -> f64 { - self.0 + fn get_variance(&self) -> Self { + Self(self.0) } - fn get_log_standard_dev(&self) -> f64 { - self.0.sqrt().log2() + fn get_log_standard_dev(&self) -> LogStandardDev { + LogStandardDev(self.0.sqrt().log2()) } - fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 { - 2_f64.powf(log2_modulus as f64 + self.0.sqrt().log2()) + fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev { + ModularStandardDev { + value: 2_f64.powf(log2_modulus as f64 + self.0.sqrt().log2()), + modulus: log2_modulus_to_modulus(log2_modulus), + } } - fn get_modular_variance(&self, log2_modulus: u32) -> f64 { - 2_f64.powf(2. * (log2_modulus as f64 + self.0.sqrt().log2())) + fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance { + ModularVariance { + value: 2_f64.powf(2. * (log2_modulus as f64 + self.0.sqrt().log2())), + modulus: log2_modulus_to_modulus(log2_modulus), + } } - fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 { - log2_modulus as f64 + self.0.sqrt().log2() + fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev { + ModularLogStandardDev { + value: log2_modulus as f64 + self.0.sqrt().log2(), + modulus: log2_modulus_to_modulus(log2_modulus), + } } } diff --git a/tfhe/src/core_crypto/commons/math/random/gaussian.rs b/tfhe/src/core_crypto/commons/math/random/gaussian.rs index 550a1c9c52..d8c2823abc 100644 --- a/tfhe/src/core_crypto/commons/math/random/gaussian.rs +++ b/tfhe/src/core_crypto/commons/math/random/gaussian.rs @@ -25,7 +25,7 @@ impl Gaussian { pub fn from_dispersion_parameter(dispersion: impl DispersionParameter, mean: f64) -> Self { Self { - std: dispersion.get_standard_dev(), + std: dispersion.get_standard_dev().0, mean, } } diff --git a/tfhe/src/core_crypto/commons/math/random/mod.rs b/tfhe/src/core_crypto/commons/math/random/mod.rs index 0f9e0ed805..4999a26554 100644 --- a/tfhe/src/core_crypto/commons/math/random/mod.rs +++ b/tfhe/src/core_crypto/commons/math/random/mod.rs @@ -253,9 +253,7 @@ impl DynamicDistribution { #[track_caller] pub fn gaussian_variance(&self) -> Variance { match self { - Self::Gaussian(gaussian) => { - Variance(StandardDev::from_standard_dev(gaussian.std).get_variance()) - } + Self::Gaussian(gaussian) => StandardDev::from_standard_dev(gaussian.std).get_variance(), Self::TUniform(_) => { panic!("Tried to get gaussian variance from a non gaussian distribution") } diff --git a/tfhe/src/core_crypto/commons/mod.rs b/tfhe/src/core_crypto/commons/mod.rs index 7e82df830d..8ceb64d110 100644 --- a/tfhe/src/core_crypto/commons/mod.rs +++ b/tfhe/src/core_crypto/commons/mod.rs @@ -93,11 +93,11 @@ pub mod test_tools { { for (x, y) in first.as_ref().iter().zip(second.as_ref().iter()) { println!("{:?}, {:?}", *x, *y); - println!("{}", dist.get_standard_dev()); + println!("{:?}", dist.get_standard_dev()); let distance: f64 = modular_distance(*x, *y).cast_into(); let torus_distance = distance / 2_f64.powi(Element::BITS as i32); assert!( - torus_distance <= 5. * dist.get_standard_dev(), + torus_distance <= 5. * dist.get_standard_dev().0, "{x} != {y} " ); }