Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(shortint): fix compression encoding change not being taken into account #1860

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions tfhe/src/shortint/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,33 @@ pub(crate) fn fill_accumulator<F, C>(
carry_modulus: CarryModulus,
f: F,
) -> u64
where
C: ContainerMut<Element = u64>,
F: Fn(u64) -> u64,
{
fill_accumulator_with_encoding(
accumulator,
polynomial_size,
glwe_size,
message_modulus,
carry_modulus,
message_modulus,
carry_modulus,
f,
)
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn fill_accumulator_with_encoding<F, C>(
accumulator: &mut GlweCiphertext<C>,
polynomial_size: PolynomialSize,
glwe_size: GlweSize,
input_message_modulus: MessageModulus,
input_carry_modulus: CarryModulus,
output_message_modulus: MessageModulus,
output_carry_modulus: CarryModulus,
f: F,
) -> u64
where
C: ContainerMut<Element = u64>,
F: Fn(u64) -> u64,
Expand All @@ -97,25 +124,25 @@ where
accumulator_view.get_mut_mask().as_mut().fill(0);

// Modulus of the msg contained in the msg bits and operations buffer
let modulus_sup = (message_modulus.0 * carry_modulus.0) as usize;
let input_modulus_sup = (input_message_modulus.0 * input_carry_modulus.0) as usize;

// N/(p/2) = size of each block
let box_size = polynomial_size.0 / modulus_sup;
let box_size = polynomial_size.0 / input_modulus_sup;

// Value of the shift we multiply our messages by
let delta = (1_u64 << 63) / (message_modulus.0 * carry_modulus.0);
let output_delta = (1_u64 << 63) / (output_message_modulus.0 * output_carry_modulus.0);

let mut body = accumulator_view.get_mut_body();
let accumulator_u64 = body.as_mut();

// Tracking the max value of the function to define the degree later
let mut max_value = 0;

for i in 0..modulus_sup {
for i in 0..input_modulus_sup {
let index = i * box_size;
let f_eval = f(i as u64);
max_value = max_value.max(f_eval);
accumulator_u64[index..index + box_size].fill(f_eval * delta);
accumulator_u64[index..index + box_size].fill(f_eval * output_delta);
}

let half_box_size = box_size / 2;
Expand Down
49 changes: 40 additions & 9 deletions tfhe/src/shortint/list_compression/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use crate::core_crypto::prelude::{
};
use crate::shortint::ciphertext::CompressedCiphertextList;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::NoiseLevel;
use crate::shortint::parameters::{CarryModulus, MessageModulus, NoiseLevel};
use crate::shortint::server_key::{
apply_programmable_bootstrap, generate_lookup_table, unchecked_scalar_mul_assign,
apply_programmable_bootstrap, generate_lookup_table_with_encoding, unchecked_scalar_mul_assign,
};
use crate::shortint::{Ciphertext, CiphertextModulus, MaxNoiseLevel};
use rayon::iter::ParallelIterator;
Expand Down Expand Up @@ -126,18 +126,49 @@ impl CompressionKey {
}

impl DecompressionKey {
pub fn unpack(&self, packed: &CompressedCiphertextList, index: usize) -> Option<Ciphertext> {
pub fn unpack(
&self,
packed: &CompressedCiphertextList,
index: usize,
) -> Result<Ciphertext, crate::Error> {
if packed.message_modulus.0 != packed.carry_modulus.0 {
return Err(crate::Error::new(format!(
"Tried to unpack values from a list where message modulus \
({:?}) is != carry modulus ({:?}), this is not supported.",
packed.message_modulus, packed.carry_modulus,
)));
}

if index >= packed.count.0 {
return None;
return Err(crate::Error::new(format!(
"Tried getting index {index} for CompressedCiphertextList \
with {} elements, out of bound access.",
packed.count.0
)));
}

let carry_extract = generate_lookup_table(
let encryption_cleartext_modulus = packed.message_modulus.0 * packed.carry_modulus.0;
// We multiply by message_modulus during compression so the actual modulus for the
// compression is smaller
let compression_cleartext_modulus = encryption_cleartext_modulus / packed.message_modulus.0;
IceTDrinker marked this conversation as resolved.
Show resolved Hide resolved
let effective_compression_message_modulus = MessageModulus(compression_cleartext_modulus);
let effective_compression_carry_modulus = CarryModulus(1);

let decompression_rescale = generate_lookup_table_with_encoding(
self.out_glwe_size(),
self.out_polynomial_size(),
packed.ciphertext_modulus,
// Input moduli are the effective compression ones
effective_compression_message_modulus,
effective_compression_carry_modulus,
// Output moduli are directly the ones stored in the list
packed.message_modulus,
packed.carry_modulus,
|x| x / packed.message_modulus.0,
// Here we do not divide by message_modulus
// Example: in the 2_2 case we are mapping a 2 bits message onto a 4 bits space, we
// want to keep the original 2 bits value in the 4 bits space, so we apply the identity
// and the encoding will rescale it for us.
|x| x,
);

let polynomial_size = packed.modulus_switched_glwe_ciphertext_list[0].polynomial_size();
Expand Down Expand Up @@ -181,14 +212,14 @@ impl DecompressionKey {
&self.blind_rotate_key,
&intermediate_lwe,
&mut output_br,
&carry_extract.acc,
&decompression_rescale.acc,
buffers,
);
});

Some(Ciphertext::new(
Ok(Ciphertext::new(
output_br,
carry_extract.degree,
decompression_rescale.degree,
NoiseLevel::NOMINAL,
packed.message_modulus,
packed.carry_modulus,
Expand Down
52 changes: 50 additions & 2 deletions tfhe/src/shortint/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,17 @@ use crate::core_crypto::prelude::ComputationBuffers;
use crate::shortint::ciphertext::{Ciphertext, Degree, MaxDegree, MaxNoiseLevel, NoiseLevel};
use crate::shortint::client_key::ClientKey;
use crate::shortint::engine::{
fill_accumulator, fill_accumulator_no_encoding, fill_many_lut_accumulator, ShortintEngine,
fill_accumulator, fill_accumulator_no_encoding, fill_accumulator_with_encoding,
fill_many_lut_accumulator, ShortintEngine,
};
use crate::shortint::parameters::{
CarryModulus, CiphertextConformanceParams, CiphertextModulus, MessageModulus,
};
use crate::shortint::{EncryptionKeyChoice, PBSOrder};
use ::tfhe_versionable::Versionize;
use aligned_vec::ABox;
use serde::{Deserialize, Serialize};
use std::fmt::{Debug, Display, Formatter};
use tfhe_versionable::Versionize;

#[cfg(feature = "pbs-stats")]
pub mod pbs_stats {
Expand Down Expand Up @@ -1563,6 +1564,53 @@ where
}
}

/// Caller needs to ensure that the operation applied is coherent from an encoding perspective.
///
/// For example:
///
/// Input encoding has 2 bits and output encoding has 4 bits, applying the identity lut would map
/// the following:
///
/// 0|00|xx -> 0|00|00
/// 0|01|xx -> 0|00|01
/// 0|10|xx -> 0|00|10
/// 0|11|xx -> 0|00|11
///
/// The reason is the identity function is computed in the input space but the scaling is done in
/// the output space, as there are more bits in the output space, the delta is smaller hence the
/// apparent "division" happening.
#[allow(clippy::too_many_arguments)]
pub(crate) fn generate_lookup_table_with_encoding<F>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ciphertext_modulus: CiphertextModulus,
input_message_modulus: MessageModulus,
input_carry_modulus: CarryModulus,
output_message_modulus: MessageModulus,
output_carry_modulus: CarryModulus,
f: F,
) -> LookupTableOwned
where
F: Fn(u64) -> u64,
{
let mut acc = GlweCiphertext::new(0, glwe_size, polynomial_size, ciphertext_modulus);
let max_value = fill_accumulator_with_encoding(
&mut acc,
polynomial_size,
glwe_size,
input_message_modulus,
input_carry_modulus,
output_message_modulus,
output_carry_modulus,
f,
);

LookupTableOwned {
acc,
degree: Degree::new(max_value),
}
}

#[derive(Copy, Clone)]
pub struct PBSConformanceParameters {
pub in_lwe_dimension: LweDimension,
Expand Down
Loading