Skip to content

Commit

Permalink
fix(shortint): fix compression encoding change not being taken into a…
Browse files Browse the repository at this point in the history
…ccount

- this maps better to what was optimized and will dramatically diminish the
pfail as we now have 2 more bits for the LUT redundancy
  • Loading branch information
IceTDrinker committed Dec 13, 2024
1 parent bdbec55 commit 8054368
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 16 deletions.
37 changes: 32 additions & 5 deletions tfhe/src/shortint/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,33 @@ pub(crate) fn fill_accumulator<F, C>(
carry_modulus: CarryModulus,
f: F,
) -> u64
where
C: ContainerMut<Element = u64>,
F: Fn(u64) -> u64,
{
fill_accumulator_with_encoding(
accumulator,
polynomial_size,
glwe_size,
message_modulus,
carry_modulus,
message_modulus,
carry_modulus,
f,
)
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn fill_accumulator_with_encoding<F, C>(
accumulator: &mut GlweCiphertext<C>,
polynomial_size: PolynomialSize,
glwe_size: GlweSize,
input_message_modulus: MessageModulus,
input_carry_modulus: CarryModulus,
output_message_modulus: MessageModulus,
output_carry_modulus: CarryModulus,
f: F,
) -> u64
where
C: ContainerMut<Element = u64>,
F: Fn(u64) -> u64,
Expand All @@ -97,25 +124,25 @@ where
accumulator_view.get_mut_mask().as_mut().fill(0);

// Modulus of the msg contained in the msg bits and operations buffer
let modulus_sup = (message_modulus.0 * carry_modulus.0) as usize;
let input_modulus_sup = (input_message_modulus.0 * input_carry_modulus.0) as usize;

// N/(p/2) = size of each block
let box_size = polynomial_size.0 / modulus_sup;
let box_size = polynomial_size.0 / input_modulus_sup;

// Value of the shift we multiply our messages by
let delta = (1_u64 << 63) / (message_modulus.0 * carry_modulus.0);
let output_delta = (1_u64 << 63) / (output_message_modulus.0 * output_carry_modulus.0);

let mut body = accumulator_view.get_mut_body();
let accumulator_u64 = body.as_mut();

// Tracking the max value of the function to define the degree later
let mut max_value = 0;

for i in 0..modulus_sup {
for i in 0..input_modulus_sup {
let index = i * box_size;
let f_eval = f(i as u64);
max_value = max_value.max(f_eval);
accumulator_u64[index..index + box_size].fill(f_eval * delta);
accumulator_u64[index..index + box_size].fill(f_eval * output_delta);
}

let half_box_size = box_size / 2;
Expand Down
49 changes: 40 additions & 9 deletions tfhe/src/shortint/list_compression/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use crate::core_crypto::prelude::{
};
use crate::shortint::ciphertext::CompressedCiphertextList;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::NoiseLevel;
use crate::shortint::parameters::{CarryModulus, MessageModulus, NoiseLevel};
use crate::shortint::server_key::{
apply_programmable_bootstrap, generate_lookup_table, unchecked_scalar_mul_assign,
apply_programmable_bootstrap, generate_lookup_table_with_encoding, unchecked_scalar_mul_assign,
};
use crate::shortint::{Ciphertext, CiphertextModulus, MaxNoiseLevel};
use rayon::iter::ParallelIterator;
Expand Down Expand Up @@ -126,18 +126,49 @@ impl CompressionKey {
}

impl DecompressionKey {
pub fn unpack(&self, packed: &CompressedCiphertextList, index: usize) -> Option<Ciphertext> {
pub fn unpack(
&self,
packed: &CompressedCiphertextList,
index: usize,
) -> Result<Ciphertext, crate::Error> {
if packed.message_modulus.0 != packed.carry_modulus.0 {
return Err(crate::Error::new(format!(
"Tried to unpack values from a list where message modulus \
({:?}) is != carry modulus ({:?}), this is not supported.",
packed.message_modulus, packed.carry_modulus,
)));
}

if index >= packed.count.0 {
return None;
return Err(crate::Error::new(format!(
"Tried getting index {index} for CompressedCiphertextList \
with {} elements, out of bound access.",
packed.count.0
)));
}

let carry_extract = generate_lookup_table(
let encryption_cleartext_modulus = packed.message_modulus.0 * packed.carry_modulus.0;
// We multiply by message_modulus during compression so the actual modulus for the
// compression is smaller
let compression_cleartext_modulus = encryption_cleartext_modulus / packed.message_modulus.0;
let effective_compression_message_modulus = MessageModulus(compression_cleartext_modulus);
let effective_compression_carry_modulus = CarryModulus(1);

let decompression_rescale = generate_lookup_table_with_encoding(
self.out_glwe_size(),
self.out_polynomial_size(),
packed.ciphertext_modulus,
// Input moduli are the effective compression ones
effective_compression_message_modulus,
effective_compression_carry_modulus,
// Output moduli are directly the ones stored in the list
packed.message_modulus,
packed.carry_modulus,
|x| x / packed.message_modulus.0,
// Here we do not divide by message_modulus
// Example: in the 2_2 case we are mapping a 2 bits message onto a 4 bits space, we
// want to keep the original 2 bits value in the 4 bits space, so we apply the identity
// and the encoding will rescale it for us.
|x| x,
);

let polynomial_size = packed.modulus_switched_glwe_ciphertext_list[0].polynomial_size();
Expand Down Expand Up @@ -181,14 +212,14 @@ impl DecompressionKey {
&self.blind_rotate_key,
&intermediate_lwe,
&mut output_br,
&carry_extract.acc,
&decompression_rescale.acc,
buffers,
);
});

Some(Ciphertext::new(
Ok(Ciphertext::new(
output_br,
carry_extract.degree,
decompression_rescale.degree,
NoiseLevel::NOMINAL,
packed.message_modulus,
packed.carry_modulus,
Expand Down
52 changes: 50 additions & 2 deletions tfhe/src/shortint/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,17 @@ use crate::core_crypto::prelude::ComputationBuffers;
use crate::shortint::ciphertext::{Ciphertext, Degree, MaxDegree, MaxNoiseLevel, NoiseLevel};
use crate::shortint::client_key::ClientKey;
use crate::shortint::engine::{
fill_accumulator, fill_accumulator_no_encoding, fill_many_lut_accumulator, ShortintEngine,
fill_accumulator, fill_accumulator_no_encoding, fill_accumulator_with_encoding,
fill_many_lut_accumulator, ShortintEngine,
};
use crate::shortint::parameters::{
CarryModulus, CiphertextConformanceParams, CiphertextModulus, MessageModulus,
};
use crate::shortint::{EncryptionKeyChoice, PBSOrder};
use ::tfhe_versionable::Versionize;
use aligned_vec::ABox;
use serde::{Deserialize, Serialize};
use std::fmt::{Debug, Display, Formatter};
use tfhe_versionable::Versionize;

#[cfg(feature = "pbs-stats")]
pub mod pbs_stats {
Expand Down Expand Up @@ -1563,6 +1564,53 @@ where
}
}

/// Caller needs to ensure that the operation applied is coherent from an encoding perspective.
///
/// For example:
///
/// Input encoding has 2 bits and output encoding has 4 bits, applying the identity lut would map
/// the following:
///
/// 0|00|xx -> 0|00|00
/// 0|01|xx -> 0|00|01
/// 0|10|xx -> 0|00|10
/// 0|11|xx -> 0|00|11
///
/// The reason is the identity function is computed in the input space but the scaling is done in
/// the output space, as there are more bits in the output space, the delta is smaller hence the
/// apparent "division" happening.
#[allow(clippy::too_many_arguments)]
pub(crate) fn generate_lookup_table_with_encoding<F>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ciphertext_modulus: CiphertextModulus,
input_message_modulus: MessageModulus,
input_carry_modulus: CarryModulus,
output_message_modulus: MessageModulus,
output_carry_modulus: CarryModulus,
f: F,
) -> LookupTableOwned
where
F: Fn(u64) -> u64,
{
let mut acc = GlweCiphertext::new(0, glwe_size, polynomial_size, ciphertext_modulus);
let max_value = fill_accumulator_with_encoding(
&mut acc,
polynomial_size,
glwe_size,
input_message_modulus,
input_carry_modulus,
output_message_modulus,
output_carry_modulus,
f,
);

LookupTableOwned {
acc,
degree: Degree::new(max_value),
}
}

#[derive(Copy, Clone)]
pub struct PBSConformanceParameters {
pub in_lwe_dimension: LweDimension,
Expand Down

0 comments on commit 8054368

Please sign in to comment.