diff --git a/tfhe/benches/integer/bench.rs b/tfhe/benches/integer/bench.rs index 4d4af899cc..4de4659688 100644 --- a/tfhe/benches/integer/bench.rs +++ b/tfhe/benches/integer/bench.rs @@ -1367,13 +1367,17 @@ mod cuda { use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; use tfhe::integer::gpu::server_key::CudaServerKey; - fn bench_cuda_server_key_unary_function_clean_inputs( + fn bench_cuda_server_key_unary_function_clean_inputs( c: &mut Criterion, bench_name: &str, display_name: &str, unary_op: F, + unary_op_cpu: G + // TODO add another argument `unary_op_cpu: G` ) where F: Fn(&CudaServerKey, &mut CudaUnsignedRadixCiphertext, &CudaStreams) + Sync, + G: Fn(&ServerKey, &mut RadixCiphertext) + Sync, + // TODO Add another generic to handle CPU function signature { let mut bench_group = c.benchmark_group(bench_name); bench_group @@ -1412,15 +1416,14 @@ mod cuda { }); } BenchmarkType::Throughput => { - let (cks, _cpu_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let (cks, cpu_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); let gpu_sks = CudaServerKey::new(&cks, &streams); - let ct = cks.encrypt_radix(gen_random_u256(&mut rng), num_block); - let mut d_ctxt = - CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams); + let clear_0 = gen_random_u256(&mut rng); + let mut ct_0 = cks.encrypt_radix(clear_0, num_block); reset_pbs_count(); - unary_op(&gpu_sks, &mut d_ctxt, &streams); + unary_op_cpu(&cpu_sks, &mut ct_0); let pbs_count = get_pbs_count(); bench_id = format!("{bench_name}::throughput::{param_name}::{bit_size}_bits"); @@ -1840,6 +1843,9 @@ mod cuda { stringify!($name), |server_key, lhs, stream| { server_key.$server_key_method(lhs, stream); + }, + |server_key_cpu, lhs| { + server_key_cpu.$server_key_method(lhs); } ) }