batching
carlomazzaferro committed Sep 20, 2024
1 parent c91f49f commit f2bd9fc
Showing 6 changed files with 283 additions and 167 deletions.
14 changes: 0 additions & 14 deletions iris-mpc-upgrade/src/bin/docker-compose.yaml
@@ -41,17 +41,3 @@ services:
     environment:
       POSTGRES_USER: "postgres"
       POSTGRES_PASSWORD: "postgres"
-  db-ui:
-    depends_on:
-      - old-db-shares-1
-      - old-db-shares-2
-      - old-db-masks-1
-      - new-db-1
-      - new-db-2
-      - new-db-3
-    image: dpage/pgadmin4
-    ports:
-      - "15432:80"
-    environment:
-      PGADMIN_DEFAULT_EMAIL: "postgres@postgres.postgres"
-      PGADMIN_DEFAULT_PASSWORD: "postgres"
18 changes: 12 additions & 6 deletions iris-mpc-upgrade/src/bin/seed_v1_dbs.rs
@@ -13,6 +13,12 @@ struct Args {
     #[clap(long)]
     fill_to: u64,
 
+    #[clap(long)]
+    create: bool,
+
+    #[clap(long)]
+    migrate: bool,
+
     #[clap(long)]
     side: String,
 }
@@ -42,25 +48,25 @@ async fn main() -> eyre::Result<()> {
             "{}/{}",
             args.shares_db_urls[0], participant_one_shares_db_name
         ),
-        migrate: false,
-        create: false,
+        migrate: args.migrate,
+        create: args.create,
     };
     let shares_db_config1 = DbConfig {
         url: format!(
             "{}/{}",
             args.shares_db_urls[1], participant_two_shares_db_name
         ),
-        migrate: false,
-        create: false,
+        migrate: args.migrate,
+        create: args.create,
     };
     let masks_db_config = DbConfig {
         url: format!(
             "{}/{}",
             args.masks_db_url.clone(),
             participant_one_masks_db_name
         ),
-        migrate: false,
-        create: false,
+        migrate: args.migrate,
+        create: args.create,
     };
 
     let shares_db0 = Db::new(&shares_db_config0).await?;
197 changes: 159 additions & 38 deletions iris-mpc-upgrade/src/bin/tcp_upgrade_client.rs
@@ -16,8 +16,12 @@ use mpc_uniqueness_check::{bits::Bits, distance::EncodedBits};
 use rand::{Rng, SeedableRng};
 use rand_chacha::ChaCha20Rng;
 use rustls::{pki_types::ServerName, ClientConfig};
-use std::{array, convert::TryFrom, pin::Pin, sync::Arc};
-use tokio::{io::AsyncWriteExt, net::TcpStream};
+use std::{array, convert::TryFrom, pin::Pin, sync::Arc, time::Duration};
+use tokio::{
+    io::{AsyncReadExt, AsyncWriteExt},
+    net::TcpStream,
+    time::timeout,
+};
 use tokio_rustls::{client::TlsStream, TlsConnector};
 use tracing::error;
 
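Note: the `(a, b, c).join().await` calls in the hunks below await a tuple of futures concurrently. The corresponding import sits outside the shown hunks, but the call shape matches the `Join` trait from the futures-concurrency crate (an assumption, not confirmed by this diff). A minimal sketch of the pattern:

use futures_concurrency::future::Join;

#[tokio::main]
async fn main() {
    // Run three independent futures concurrently and collect their outputs
    let (a, b, c) = (
        async { 1u8 },
        async { 2u8 },
        async { 3u8 },
    )
        .join()
        .await;
    assert_eq!((a, b, c), (1u8, 2u8, 3u8));
}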
@@ -79,6 +83,10 @@ async fn main() -> eyre::Result<()> {
     let mut server2 = prepare_tls_stream_for_writing(&args.server2, client_config.clone()).await?;
     let mut server3 = prepare_tls_stream_for_writing(&args.server3, client_config).await?;
 
+    // let mut server1 = TcpStream::connect(&args.server1).await?;
+    // let mut server2 = TcpStream::connect(&args.server2).await?;
+    // let mut server3 = TcpStream::connect(&args.server3).await?;
+
     tracing::info!("Connecting to servers and syncing migration task parameters...");
     server1.write_u8(args.party_id).await?;
     server2.write_u8(args.party_id).await?;
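`prepare_tls_stream_for_writing` is defined elsewhere in this file and not shown in the diff. A hypothetical sketch of what such a tokio-rustls helper typically looks like (the function name here and the host-splitting detail are assumptions, not the repository's code):

use std::sync::Arc;
use rustls::{pki_types::ServerName, ClientConfig};
use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, TlsConnector};

async fn connect_tls(
    addr: &str,
    config: Arc<ClientConfig>,
) -> eyre::Result<TlsStream<TcpStream>> {
    // Plain TCP connect first, then wrap the socket in TLS
    let tcp = TcpStream::connect(addr).await?;
    // rustls wants the bare host name (no port) for SNI and cert validation
    let host = addr.split(':').next().unwrap_or(addr).to_owned();
    let server_name = ServerName::try_from(host)?;
    Ok(TlsConnector::from(config).connect(server_name, tcp).await?)
}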
@@ -155,89 +163,202 @@ async fn main() -> eyre::Result<()> {
     let num_iris_codes = end - start;
     tracing::info!("Processing {} iris codes", num_iris_codes);
 
+    let batch_size = args.batch_size;
+    let mut batch = Vec::with_capacity(batch_size as usize);
+
     while let Some(share_res) = shares_stream.next().await {
         let (share_id, share) = share_res?;
         let (mask_id, mask) = mask_stream
             .next()
             .await
             .context("mask stream ended before share stream did")??;
 
         eyre::ensure!(
             share_id == mask_id,
             "Share and mask streams out of sync: {} != {}",
             share_id,
             mask_id
         );
-        let id = share_id;
-        tracing::info!("Processing id: {}", id);
 
+        // Prepare the shares and masks for this item
         let [mask_share_a, mask_share_b, mask_share_c] =
-            get_shares_from_masks(args.party_id, id, &mask, &mut rng);
+            get_shares_from_masks(args.party_id, share_id, &mask, &mut rng);
         let [iris_share_a, iris_share_b, iris_share_c] =
-            get_shares_from_shares(args.party_id, id, &share, &mut rng);
+            get_shares_from_shares(args.party_id, share_id, &share, &mut rng);
 
+        // Add to batch
+        batch.push((
+            iris_share_a,
+            iris_share_b,
+            iris_share_c,
+            mask_share_a,
+            mask_share_b,
+            mask_share_c,
+        ));
+
+        // If the batch is full, send it and wait for the ACK
+        if batch.len() == batch_size as usize {
+            tracing::info!("Sending batch of size {}", batch_size);
+            send_batch_and_wait_for_ack(
+                args.party_id,
+                &mut server1,
+                &mut server2,
+                &mut server3,
+                &batch,
+            )
+            .await?;
+            batch.clear(); // Clear the batch once the ACK is received
+        }
     }
+    // Send the remaining elements in the last, possibly partial batch
+    if !batch.is_empty() {
+        tracing::info!("Sending final batch of size {}", batch.len());
+        send_batch_and_wait_for_ack(
+            args.party_id,
+            &mut server1,
+            &mut server2,
+            &mut server3,
+            &batch,
+        )
+        .await?;
+        batch.clear();
+    }
+    // Wait for the final ACK (value 42) from each server
+    tracing::info!("Final batch sent, waiting for final ACKs");
+    wait_for_ack(&mut server1).await?;
+    tracing::info!("Server 1 final ACK received");
+    wait_for_ack(&mut server2).await?;
+    tracing::info!("Server 2 final ACK received");
+    wait_for_ack(&mut server3).await?;
+    tracing::info!("Server 3 final ACK received");
+    Ok(())
+}
 
-        let mut errors = Vec::new();
+async fn send_batch_and_wait_for_ack(
+    party_id: u8,
+    server1: &mut TlsStream<TcpStream>,
+    server2: &mut TlsStream<TcpStream>,
+    server3: &mut TlsStream<TcpStream>,
+    batch: &[(
+        TwoToThreeIrisCodeMessage,
+        TwoToThreeIrisCodeMessage,
+        TwoToThreeIrisCodeMessage,
+        MaskShareMessage,
+        MaskShareMessage,
+        MaskShareMessage,
+    )],
+) -> eyre::Result<()> {
+    let mut errors = Vec::new();
+    let batch_size = batch.len();
+    // Send the batch size (a single byte, so at most 255) to all servers
+    let (batch_size_result_a, batch_size_result_b, batch_size_result_c) = (
+        server1.write_u8(batch_size as u8),
+        server2.write_u8(batch_size as u8),
+        server3.write_u8(batch_size as u8),
+    )
+        .join()
+        .await;
+
+    if let Err(e) = batch_size_result_a {
+        error!("Failed to send batch size to server1: {:?}", e);
+        errors.push(e.to_string());
+    }
+    if let Err(e) = batch_size_result_b {
+        error!("Failed to send batch size to server2: {:?}", e);
+        errors.push(e.to_string());
+    }
+    if let Err(e) = batch_size_result_c {
+        error!("Failed to send batch size to server3: {:?}", e);
+        errors.push(e.to_string());
+    }
+
-        let (result_share_a, result_share_b, result_share_c) = (
-            iris_share_a.send(&mut server1),
-            iris_share_b.send(&mut server2),
-            iris_share_c.send(&mut server3),
+    // Send the batch to all servers
+    for (iris1, iris2, iris3, mask1, mask2, mask3) in batch {
+        let (result_iris_a, result_iris_b, result_iris_c) = (
+            iris1.send(server1),
+            iris2.send(server2),
+            iris3.send(server3),
         )
             .join()
             .await;
 
-        if let Err(e) = result_share_a {
-            error!("Failed to send message to server1 (party_id: 0): {:?}", e);
+        // Handle sending errors
+        if let Err(e) = result_iris_a {
+            error!("Failed to send message to server1: {:?}", e);
             errors.push(e.to_string());
         }
 
-        if let Err(e) = result_share_b {
-            error!("Failed to send message to server2 (party_id: 1): {:?}", e);
+        if let Err(e) = result_iris_b {
+            error!("Failed to send message to server2: {:?}", e);
             errors.push(e.to_string());
         }
 
-        if let Err(e) = result_share_c {
-            error!("Failed to send message to server3 (party_id: 2): {:?}", e);
+        if let Err(e) = result_iris_c {
+            error!("Failed to send message to server3: {:?}", e);
             errors.push(e.to_string());
         }
 
-        if args.party_id == 0 {
+        // Send mask shares (only sent by party 0)
+        if party_id == 0 {
             let (result_mask_a, result_mask_b, result_mask_c) = (
-                mask_share_a.send(&mut server1),
-                mask_share_b.send(&mut server2),
-                mask_share_c.send(&mut server3),
+                mask1.send(server1),
+                mask2.send(server2),
+                mask3.send(server3),
             )
                 .join()
                 .await;
             if let Err(e) = result_mask_a {
-                error!("Failed to send message to server1 (party_id: 0): {:?}", e);
+                error!("Failed to send mask to server1: {:?}", e);
                 errors.push(e.to_string());
             }
 
             if let Err(e) = result_mask_b {
-                error!("Failed to send message to server2 (party_id: 1): {:?}", e);
+                error!("Failed to send mask to server2: {:?}", e);
                 errors.push(e.to_string());
             }
 
             if let Err(e) = result_mask_c {
-                error!("Failed to send message to server3 (party_id: 2): {:?}", e);
+                error!("Failed to send mask to server3: {:?}", e);
                 errors.push(e.to_string());
             }
         }
+        // If any errors occurred for this item, abort the whole batch
+        if !errors.is_empty() {
+            let combined_error = errors.join(" || ");
+            error!(combined_error);
+            return Err(eyre::eyre!(combined_error));
+        }
+    }

tracing::info!("Finished id: {}", id);
// Handle acknowledgment: Wait for ACK from server 1 (assuming it sends ACK for
// all)
if !errors.is_empty() {
let combined_error = errors.join(" || ");
return Err(eyre::eyre!(combined_error));
}
tracing::info!("Processing done!");
wait_for_ack(server1).await?;
wait_for_ack(server2).await?;
wait_for_ack(server3).await?;
Ok(())
}

+async fn wait_for_ack(server: &mut TlsStream<TcpStream>) -> eyre::Result<()> {
+    match timeout(Duration::from_secs(10), server.read_u8()).await {
+        Ok(Ok(1)) => {
+            // Per-batch ACK received successfully
+            tracing::info!("ACK received for batch");
+            Ok(())
+        }
+        Ok(Ok(42)) => {
+            // Final-batch ACK received successfully
+            tracing::info!("ACK received for final batch");
+            Ok(())
+        }
+        Ok(Ok(_)) => {
+            error!("Received invalid ACK");
+            Err(eyre::eyre!("Invalid ACK received"))
+        }
+        Ok(Err(e)) => {
+            error!("Error reading ACK: {:?}", e);
+            Err(e.into())
+        }
+        Err(_) => {
+            error!("ACK timeout");
+            Err(eyre::eyre!("ACK timeout"))
+        }
+    }
+}
 
 struct V1Database {
     db: V1Db,
 }
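The server side of this batching protocol lives in files not loaded on this page. Purely as an illustration of the framing the client implies (a one-byte batch size, then the batch items, then a one-byte ACK of 1 per batch, with 42 reserved for the final ACK), here is a hypothetical sketch; `read_message` and the loop-exit condition are assumptions, not the repository's actual code:

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;

// Hypothetical server loop matching the client's framing.
async fn handle_client(stream: &mut TcpStream) -> eyre::Result<()> {
    loop {
        // One byte announcing how many items the next batch holds
        let batch_size = stream.read_u8().await? as usize;
        for _ in 0..batch_size {
            read_message(stream).await?;
        }
        // ACK the batch so the client sends the next one; the final ACK
        // byte 42 would presumably be sent once the stream is exhausted
        stream.write_u8(1).await?;
    }
}

// Stand-in for decoding one message; the real server would parse
// TwoToThreeIrisCodeMessage and MaskShareMessage frames here.
async fn read_message(stream: &mut TcpStream) -> eyre::Result<()> {
    let mut buf = [0u8; 16];
    stream.read_exact(&mut buf).await?;
    Ok(())
}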
(Diffs for the remaining 3 changed files are not shown here.)
