Skip to content

Commit

Permalink
Change chunk size for Prio and fix various clippy warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Feb 5, 2024
1 parent 927f62e commit 091d43c
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 33 deletions.
17 changes: 9 additions & 8 deletions src/bin/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,16 @@ impl PlaintextReport {
..
} => [
ReportShare::Mastic {
nonce: nonce.clone(),
nonce: *nonce,
vidpf_key: vidpf_keys[0].clone(),
flp_proof_share: flp_proof_shares[0].clone(),
flp_joint_rand_parts: flp_joint_rand_parts.clone(),
flp_joint_rand_parts: *flp_joint_rand_parts,
},
ReportShare::Mastic {
nonce: nonce.clone(),
nonce: *nonce,
vidpf_key: vidpf_keys[1].clone(),
flp_proof_share: flp_proof_shares[1].clone(),
flp_joint_rand_parts: flp_joint_rand_parts.clone(),
flp_joint_rand_parts: *flp_joint_rand_parts,
},
],
Self::Prio3 {
Expand All @@ -115,12 +115,12 @@ impl PlaintextReport {
input_shares,
} => [
ReportShare::Prio3 {
nonce: nonce.clone(),
nonce: *nonce,
public_share_bytes: public_share.get_encoded(),
input_share_bytes: input_shares[0].get_encoded(),
},
ReportShare::Prio3 {
nonce: nonce.clone(),
nonce: *nonce,
public_share_bytes: public_share.get_encoded(),
input_share_bytes: input_shares[1].get_encoded(),
},
Expand Down Expand Up @@ -199,7 +199,8 @@ fn generate_reports(cfg: &config::Config, mastic: &MasticHistogram) -> Vec<Plain
}
}
Mode::PlainMetrics => {
let chunk_length = histogram_chunk_length(mastic.input_len());
let chunk_length =
histogram_chunk_length(mastic.input_len(), Mode::PlainMetrics);
let prio3 = Prio3::new_histogram(2, mastic.input_len(), chunk_length).unwrap();
let (public_share, input_shares) = prio3.shard(&bucket, &nonce).unwrap();

Expand Down Expand Up @@ -616,7 +617,7 @@ async fn run_plain_metrics(
client_1: &CollectorClient,
num_clients: usize,
) -> io::Result<()> {
let chunk_length = histogram_chunk_length(mastic.input_len());
let chunk_length = histogram_chunk_length(mastic.input_len(), Mode::PlainMetrics);
let prio3 = Prio3::new_histogram(2, mastic.input_len(), chunk_length).unwrap();

for start in (0..num_clients).step_by(cfg.flp_batch_size) {
Expand Down
24 changes: 13 additions & 11 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use futures::{future, prelude::*};
use itertools::Itertools;
use mastic::{
collect::{self, ReportShare},
config, histogram_chunk_length, prg,
config::{self, Mode},
histogram_chunk_length, prg,
rpc::{
AddReportSharesRequest, ApplyFLPResultsRequest, AttributeBasedMetricsResultRequest,
AttributeBasedMetricsValidateRequest, Collector, FinalSharesRequest, GetProofsRequest,
Expand Down Expand Up @@ -174,7 +175,7 @@ impl Collector for CollectorServer {
.mastic
.query(
&beta_share,
&coll.report_shares[client_index].1.unwrap_flp_proof_share(),
coll.report_shares[client_index].1.unwrap_flp_proof_share(),
&query_rand,
&joint_rand,
2,
Expand All @@ -185,7 +186,7 @@ impl Collector for CollectorServer {
client_index,
values_share,
verifier_share,
eval_proof.finalize().as_bytes().clone(),
*eval_proof.finalize().as_bytes(),
)
}));

Expand All @@ -208,10 +209,11 @@ impl Collector for CollectorServer {
let mut coll = self.arc.lock().unwrap();

for rejected_client_index in req.rejected {
debug_assert!(coll
.attribute_based_metrics_state
.remove(&rejected_client_index)
.is_some());
debug_assert!(
coll.attribute_based_metrics_state
.remove(&rejected_client_index)
.is_some()
);
}

let mut agg_share =
Expand All @@ -234,7 +236,7 @@ impl Collector for CollectorServer {
let mut coll = self.arc.lock().unwrap();
let mut results = Vec::with_capacity(req.end - req.start);
let agg_id = self.server_id.try_into().unwrap();
let chunk_length = histogram_chunk_length(coll.mastic.input_len());
let chunk_length = histogram_chunk_length(coll.mastic.input_len(), Mode::PlainMetrics);
let prio3 = Prio3::new_histogram(2, coll.mastic.input_len(), chunk_length).unwrap();

results.par_extend((req.start..req.end).into_par_iter().map(|client_index| {
Expand All @@ -248,9 +250,9 @@ impl Collector for CollectorServer {
};

let public_share =
Prio3PublicShare::get_decoded_with_param(&prio3, &public_share_bytes).unwrap();
Prio3PublicShare::get_decoded_with_param(&prio3, public_share_bytes).unwrap();
let input_share =
Prio3InputShare::get_decoded_with_param(&(&prio3, agg_id), &input_share_bytes)
Prio3InputShare::get_decoded_with_param(&(&prio3, agg_id), input_share_bytes)
.unwrap();

let (prep_state, prep_share) = prio3
Expand Down Expand Up @@ -284,7 +286,7 @@ impl Collector for CollectorServer {
) -> (AggregateShare<Field128>, usize) {
debug_assert!(req.start < req.end);
let mut coll = self.arc.lock().unwrap();
let chunk_length = histogram_chunk_length(coll.mastic.input_len());
let chunk_length = histogram_chunk_length(coll.mastic.input_len(), Mode::PlainMetrics);
let prio3 = Prio3::new_histogram(2, coll.mastic.input_len(), chunk_length).unwrap();

let out_shares = (req.start..req.end)
Expand Down
9 changes: 4 additions & 5 deletions src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ impl ReportShare {

fn nonce(&self) -> &[u8; 16] {
match self {
Self::Mastic { nonce, .. } => &nonce,
Self::Prio3 { nonce, .. } => &nonce,
Self::Mastic { nonce, .. } => nonce,
Self::Prio3 { nonce, .. } => nonce,
}
}
}
Expand Down Expand Up @@ -289,10 +289,9 @@ impl KeyCollection {
/// derive the correct joint randomness part from its input share and send it to the other so
/// that they can check if the advertised parts were actually computed correctly.
pub fn flp_joint_rand(&self, client_index: usize) -> Vec<Field128> {
let mut jr_parts = self.report_shares[client_index]
let mut jr_parts = *self.report_shares[client_index]
.1
.unwrap_flp_joint_rand_parts()
.clone();
.unwrap_flp_joint_rand_parts();
if self.server_id == 0 {
let mut jr_part_xof = XofShake128::init(
&self.report_shares[client_index]
Expand Down
23 changes: 18 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod vidpf;

extern crate lazy_static;

use config::Mode;
use prio::{
field::Field128,
flp::{
Expand All @@ -20,10 +21,14 @@ pub use crate::rpc::CollectorClient;

pub const HASH_SIZE: usize = 16;

pub fn histogram_chunk_length(_num_buckets: usize) -> usize {
// NOTE(cjpatton) The "asymptotically optimal" chunk length is `(num_buckets as f64).sqrt() as
// usize`. However our histograms are so small that a constant size seems to perform better.
2
pub fn histogram_chunk_length(num_buckets: usize, mode: Mode) -> usize {
// The "asymptotically optimal" chunk length is `(num_buckets as f64).sqrt()
// as usize`. However Mastic histograms are so small that a constant size seems
// to perform better. For PlainMetrics, we use bigger histograms.
match mode {
Mode::WeightedHeavyHitters { .. } | Mode::AttributeBasedMetrics { .. } => 2,
Mode::PlainMetrics => (num_buckets as f64).sqrt() as usize,
}
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -113,7 +118,15 @@ impl MasticHistogram {
/// Constructs an instance of MasticHistogram with the given number of aggregators,
/// number of buckets, and parallel sum gadget chunk length.
pub fn new_histogram(length: usize) -> Result<Self, VdafError> {
Mastic::new(Histogram::new(length, histogram_chunk_length(length))?)
Mastic::new(Histogram::new(
length,
histogram_chunk_length(
length,
Mode::WeightedHeavyHitters {
threshold: 0.0, // Unused here.
},
),
)?)
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/vidpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl VidpfEvalNode {
input_len: usize,
eval_proof: &mut blake3::Hasher,
) -> &mut VidpfEvalNode {
if path.len() == 0 {
if path.is_empty() {
return self;
}

Expand All @@ -108,7 +108,7 @@ impl VidpfEvalNode {
if self.l.is_none() {
let l = self.next(false, key, input_len);
if let (Some(ref mut path_check), Some(ref word_share_l)) = (&mut p, &l.word_share) {
vec_sub(path_check, &word_share_l);
vec_sub(path_check, word_share_l);
}
self.l = Some(Box::new(l));
}
Expand All @@ -117,7 +117,7 @@ impl VidpfEvalNode {
if self.r.is_none() {
let r = self.next(true, key, input_len);
if let (Some(ref mut path_check), Some(ref word_share_r)) = (&mut p, &r.word_share) {
vec_sub(path_check, &word_share_r);
vec_sub(path_check, word_share_r);
}
self.r = Some(Box::new(r));
}
Expand Down Expand Up @@ -482,10 +482,10 @@ impl VidpfKey {

#[cfg(test)]
mod tests {
use crate::string_to_bits;
use prio::field::FieldElement;

use super::*;
use crate::string_to_bits;

/// Test the VIDPF functionality required for the attribute-based metrics use case.
#[test]
Expand Down

0 comments on commit 091d43c

Please sign in to comment.