Skip to content

Commit

Permalink
feat(core): add batched programmable bootstrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
soonum committed Oct 18, 2024
1 parent f3a1b6b commit 54ef252
Show file tree
Hide file tree
Showing 6 changed files with 555 additions and 23 deletions.
171 changes: 148 additions & 23 deletions tfhe/benches/core_crypto/pbs_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,28 @@ use tfhe::core_crypto::prelude::*;
use tfhe::keycache::NamedParam;
use tfhe::shortint::parameters::*;

const SHORTINT_BENCH_PARAMS: [ClassicPBSParameters; 19] = [
PARAM_MESSAGE_1_CARRY_0_KS_PBS,
PARAM_MESSAGE_1_CARRY_1_KS_PBS,
PARAM_MESSAGE_2_CARRY_0_KS_PBS,
PARAM_MESSAGE_2_CARRY_1_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_3_CARRY_0_KS_PBS,
PARAM_MESSAGE_3_CARRY_2_KS_PBS,
PARAM_MESSAGE_3_CARRY_3_KS_PBS,
PARAM_MESSAGE_4_CARRY_0_KS_PBS,
PARAM_MESSAGE_4_CARRY_3_KS_PBS,
PARAM_MESSAGE_4_CARRY_4_KS_PBS,
PARAM_MESSAGE_5_CARRY_0_KS_PBS,
PARAM_MESSAGE_6_CARRY_0_KS_PBS,
PARAM_MESSAGE_7_CARRY_0_KS_PBS,
PARAM_MESSAGE_8_CARRY_0_KS_PBS,
PARAM_MESSAGE_1_CARRY_1_PBS_KS,
PARAM_MESSAGE_2_CARRY_2_PBS_KS,
PARAM_MESSAGE_3_CARRY_3_PBS_KS,
PARAM_MESSAGE_4_CARRY_4_PBS_KS,
];
// const SHORTINT_BENCH_PARAMS: [ClassicPBSParameters; 19] = [
// PARAM_MESSAGE_1_CARRY_0_KS_PBS,
// PARAM_MESSAGE_1_CARRY_1_KS_PBS,
// PARAM_MESSAGE_2_CARRY_0_KS_PBS,
// PARAM_MESSAGE_2_CARRY_1_KS_PBS,
// PARAM_MESSAGE_2_CARRY_2_KS_PBS,
// PARAM_MESSAGE_3_CARRY_0_KS_PBS,
// PARAM_MESSAGE_3_CARRY_2_KS_PBS,
// PARAM_MESSAGE_3_CARRY_3_KS_PBS,
// PARAM_MESSAGE_4_CARRY_0_KS_PBS,
// PARAM_MESSAGE_4_CARRY_3_KS_PBS,
// PARAM_MESSAGE_4_CARRY_4_KS_PBS,
// PARAM_MESSAGE_5_CARRY_0_KS_PBS,
// PARAM_MESSAGE_6_CARRY_0_KS_PBS,
// PARAM_MESSAGE_7_CARRY_0_KS_PBS,
// PARAM_MESSAGE_8_CARRY_0_KS_PBS,
// PARAM_MESSAGE_1_CARRY_1_PBS_KS,
// PARAM_MESSAGE_2_CARRY_2_PBS_KS,
// PARAM_MESSAGE_3_CARRY_3_PBS_KS,
// PARAM_MESSAGE_4_CARRY_4_PBS_KS,
// ];
const SHORTINT_BENCH_PARAMS: [ClassicPBSParameters; 1] = [PARAM_MESSAGE_2_CARRY_2_KS_PBS];

const BOOLEAN_BENCH_PARAMS: [(&str, BooleanParameters); 2] = [
("BOOLEAN_DEFAULT_PARAMS", DEFAULT_PARAMETERS),
Expand Down Expand Up @@ -216,10 +217,133 @@ fn mem_optimized_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
{
bench_group.bench_function(&id, |b| {
b.iter(|| {
programmable_bootstrap_lwe_ciphertext_mem_optimized(
for _ in 0..10 {
programmable_bootstrap_lwe_ciphertext_mem_optimized(
&lwe_ciphertext_in,
&mut out_pbs_ct,
&accumulator.as_view(),
&fourier_bsk,
fft,
buffers.stack(),
);
black_box(&mut out_pbs_ct);
}
})
});
}

let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
write_to_json(
&id,
*params,
name,
"pbs",
&OperatorType::Atomic,
bit_size,
vec![bit_size],
);
}
}

fn mem_optimized_batched_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
c: &mut Criterion,
parameters: &[(String, CryptoParametersRecord<Scalar>)],
) {
let bench_name = "core_crypto::batched_pbs_mem_optimized";
let mut bench_group = c.benchmark_group(bench_name);
bench_group
.sample_size(15)
.measurement_time(std::time::Duration::from_secs(10));

// Create the PRNG
let mut seeder = new_seeder();
let seeder = seeder.as_mut();
let mut encryption_generator =
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
let mut secret_generator =
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());

for (name, params) in parameters.iter() {
// Create the LweSecretKey
let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
params.lwe_dimension.unwrap(),
&mut secret_generator,
);
let output_glwe_secret_key: GlweSecretKeyOwned<Scalar> =
allocate_and_generate_new_binary_glwe_secret_key(
params.glwe_dimension.unwrap(),
params.polynomial_size.unwrap(),
&mut secret_generator,
);
let output_lwe_secret_key = output_glwe_secret_key.into_lwe_secret_key();

// Create the empty bootstrapping key in the Fourier domain
let fourier_bsk = FourierLweBootstrapKey::new(
params.lwe_dimension.unwrap(),
params.glwe_dimension.unwrap().to_glwe_size(),
params.polynomial_size.unwrap(),
params.pbs_base_log.unwrap(),
params.pbs_level.unwrap(),
);

let count = 10; // FIXME Is it a representative value (big enough?)

// Allocate a new LweCiphertext and encrypt our plaintext
let mut lwe_ciphertext_in = LweCiphertextListOwned::<Scalar>::new(
Scalar::ZERO,
input_lwe_secret_key.lwe_dimension().to_lwe_size(),
LweCiphertextCount(count),
params.ciphertext_modulus.unwrap(),
);

encrypt_lwe_ciphertext_list(
&input_lwe_secret_key,
&mut lwe_ciphertext_in,
&PlaintextList::from_container(vec![Scalar::ZERO; count]),
params.lwe_noise_distribution.unwrap(),
&mut encryption_generator,
);

let accumulator = GlweCiphertextList::new(
Scalar::ZERO,
params.glwe_dimension.unwrap().to_glwe_size(),
params.polynomial_size.unwrap(),
GlweCiphertextCount(count),
params.ciphertext_modulus.unwrap(),
);

// Allocate the LweCiphertext to store the result of the PBS
let mut out_pbs_ct = LweCiphertextList::new(
Scalar::ZERO,
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
LweCiphertextCount(count),
params.ciphertext_modulus.unwrap(),
);

let mut buffers = ComputationBuffers::new();

let fft = Fft::new(fourier_bsk.polynomial_size());
let fft = fft.as_view();

buffers.resize(
batch_programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<Scalar>(
fourier_bsk.glwe_size(),
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required()
* count,
);

let id = format!("{bench_name}::{name}");
{
bench_group.bench_function(&id, |b| {
b.iter(|| {
batched_programmable_bootstrap_lwe_ciphertext_mem_optimized(
&lwe_ciphertext_in,
&mut out_pbs_ct,
&accumulator.as_view(),
&accumulator,
&fourier_bsk,
fft,
buffers.stack(),
Expand Down Expand Up @@ -1310,6 +1434,7 @@ pub fn pbs_group() {
mem_optimized_pbs(&mut criterion, &benchmark_parameters_64bits());
mem_optimized_pbs(&mut criterion, &benchmark_parameters_32bits());
mem_optimized_pbs_ntt(&mut criterion);
mem_optimized_batched_pbs(&mut criterion, &benchmark_parameters_64bits());
}

pub fn multi_bit_pbs_group() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1071,3 +1071,77 @@ pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<OutputSca
) -> Result<StackReq, SizeOverflow> {
bootstrap_scratch::<OutputScalar>(glwe_size, polynomial_size, fft)
}

/// Memory optimized programmable bootstrap applied to a whole list of LWE ciphertexts.
///
/// After validating its arguments, this forwards views of `output`, `input` and `accumulator`,
/// together with `fft` and `stack`, to `batch_bootstrap` on the view of `fourier_bsk`.
///
/// The caller must provide a properly configured [`FftView`] object and a `PodStack` used as a
/// memory buffer having a capacity at least as large as the result of
/// [`crate::core_crypto::algorithms::batch_programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement`].
///
/// NOTE(review): nothing here checks that `input`, `output` and `accumulator` hold the same
/// number of ciphertexts -- confirm that `batch_bootstrap` validates this.
///
/// # Panics
///
/// Panics if:
/// - the ciphertext moduli of `accumulator` and `output` differ;
/// - the LWE dimension of `input` differs from the input LWE dimension of `fourier_bsk`;
/// - the LWE dimension of `output` differs from the output LWE dimension of `fourier_bsk`.
pub fn batch_programmable_bootstrap_lwe_ciphertext_mem_optimized<
    InputScalar,
    OutputScalar,
    InputCont,
    OutputCont,
    AccCont,
    KeyCont,
>(
    input: &LweCiphertextList<InputCont>,
    output: &mut LweCiphertextList<OutputCont>,
    accumulator: &GlweCiphertextList<AccCont>,
    fourier_bsk: &FourierLweBootstrapKey<KeyCont>,
    fft: FftView<'_>,
    stack: PodStack<'_>,
) where
    // CastInto required for PBS modulus switch which returns a usize
    InputScalar: UnsignedTorus + CastInto<usize>,
    OutputScalar: UnsignedTorus,
    InputCont: Container<Element = InputScalar>,
    OutputCont: ContainerMut<Element = OutputScalar>,
    AccCont: Container<Element = OutputScalar>,
    KeyCont: Container<Element = c64>,
{
    // The bootstrap result is extracted from the accumulator, so both must share a modulus.
    assert_eq!(
        accumulator.ciphertext_modulus(),
        output.ciphertext_modulus(),
        "Mismatched moduli between accumulator ({:?}) and output ({:?})",
        accumulator.ciphertext_modulus(),
        output.ciphertext_modulus()
    );

    assert_eq!(
        fourier_bsk.input_lwe_dimension(),
        input.lwe_size().to_lwe_dimension(),
        "Mismatched input LweDimension. \
        FourierLweBootstrapKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
        fourier_bsk.input_lwe_dimension(),
        input.lwe_size().to_lwe_dimension(),
    );
    // The message names the output side: it previously repeated the "input" wording of the
    // assertion above, which pointed at the wrong argument when this check failed.
    assert_eq!(
        fourier_bsk.output_lwe_dimension(),
        output.lwe_size().to_lwe_dimension(),
        "Mismatched output LweDimension. \
        FourierLweBootstrapKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
        fourier_bsk.output_lwe_dimension(),
        output.lwe_size().to_lwe_dimension(),
    );

    fourier_bsk.as_view().batch_bootstrap(
        output.as_mut_view(),
        input.as_view(),
        &accumulator.as_view(),
        fft,
        stack,
    );
}

/// Return the required memory for [`batch_programmable_bootstrap_lwe_ciphertext_mem_optimized`].
///
/// This forwards `glwe_size`, `polynomial_size` and `fft` unchanged to
/// [`programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement`] and returns that
/// function's result as is.
///
/// NOTE(review): the signature takes no ciphertext count, so the returned requirement cannot
/// depend on the batch size -- confirm whether callers are expected to scale it by the number
/// of ciphertexts in the batch.
///
/// # Errors
///
/// Propagates any [`SizeOverflow`] returned by the forwarded call.
pub fn batch_programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<OutputScalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
// Delegate to the single-ciphertext requirement with the same arguments.
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<OutputScalar>(
glwe_size,
polynomial_size,
fft,
)
}
Loading

0 comments on commit 54ef252

Please sign in to comment.