From 31171d5bda91a6d34396435e287f3d0259dcd0a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Nicolas?= Date: Wed, 15 Jan 2025 15:19:30 +0100 Subject: [PATCH] aurel/hawk-main: Start HNSW MPC node --- iris-mpc-cpu/Cargo.toml | 8 +- iris-mpc-cpu/bin/hawk_main.rs | 8 ++ iris-mpc-cpu/src/execution/hawk_main.rs | 174 ++++++++++++++++++++++++ iris-mpc-cpu/src/execution/mod.rs | 1 + 4 files changed, 189 insertions(+), 2 deletions(-) create mode 100644 iris-mpc-cpu/bin/hawk_main.rs create mode 100644 iris-mpc-cpu/src/execution/hawk_main.rs diff --git a/iris-mpc-cpu/Cargo.toml b/iris-mpc-cpu/Cargo.toml index dcfdceb1e..000298be3 100644 --- a/iris-mpc-cpu/Cargo.toml +++ b/iris-mpc-cpu/Cargo.toml @@ -8,11 +8,11 @@ license.workspace = true repository.workspace = true [dependencies] -aes-prng = { git = "https://github.com/tf-encrypted/aes-prng.git", branch = "dragos/display"} +aes-prng = { git = "https://github.com/tf-encrypted/aes-prng.git", branch = "dragos/display" } async-channel = "2.3.1" async-stream = "0.3.6" async-trait = "~0.1" -backoff = {version="0.4.0", features = ["tokio"]} +backoff = { version = "0.4.0", features = ["tokio"] } bincode.workspace = true bytes = "1.7" bytemuck.workspace = true @@ -59,3 +59,7 @@ path = "bin/local_hnsw.rs" [[bin]] name = "generate_benchmark_data" path = "bin/generate_benchmark_data.rs" + +[[bin]] +name = "hawk_main" +path = "bin/hawk_main.rs" diff --git a/iris-mpc-cpu/bin/hawk_main.rs b/iris-mpc-cpu/bin/hawk_main.rs new file mode 100644 index 000000000..fcd7b608d --- /dev/null +++ b/iris-mpc-cpu/bin/hawk_main.rs @@ -0,0 +1,8 @@ +use clap::Parser; +use eyre::Result; +use iris_mpc_cpu::execution::hawk_main::{hawk_main, HawkArgs}; + +#[tokio::main] +async fn main() -> Result<()> { + hawk_main(HawkArgs::parse()).await +} diff --git a/iris-mpc-cpu/src/execution/hawk_main.rs b/iris-mpc-cpu/src/execution/hawk_main.rs new file mode 100644 index 000000000..af76e37eb --- /dev/null +++ b/iris-mpc-cpu/src/execution/hawk_main.rs @@ -0,0 +1,174 @@ +use crate::{ + database_generators::generate_galois_iris_shares, + execution::{ + local::generate_local_identities, + player::{Role, RoleAssignment}, + session::{BootSession, Session, SessionId}, + }, + hawkers::aby3_store::{Aby3Store, SharedIrisesRef}, + hnsw::HnswSearcher, + network::grpc::{GrpcConfig, GrpcNetworking}, + proto_generated::party_node::party_node_server::PartyNodeServer, + protocol::ops::setup_replicated_prf, +}; +use aes_prng::AesRng; +use clap::Parser; +use eyre::Result; +use hawk_pack::graph_store::GraphMem; +use iris_mpc_common::iris_db::db::IrisDB; +use itertools::{izip, Itertools}; +use rand::SeedableRng; +use std::{sync::Arc, time::Duration}; +use tokio::task::JoinSet; +use tonic::transport::Server; + +#[derive(Parser)] +pub struct HawkArgs { + #[clap(short, long)] + party_index: usize, +} + +pub async fn hawk_main(args: HawkArgs) -> Result<()> { + // ---- Shared setup ---- + + let n_parties = 3; + let port_start = 40000; + let search_params = HnswSearcher::default(); + + let identities = generate_local_identities(); + + let addresses = (0..n_parties) + .map(|i| format!("127.0.0.1:{}", port_start + i)) + .collect_vec(); + + let role_assignments: RoleAssignment = identities + .iter() + .enumerate() + .map(|(index, id)| (Role::new(index), id.clone())) + .collect(); + + // ---- My network setup ---- + + let my_index = args.party_index; + let my_identity = identities[my_index].clone(); + let my_address = addresses[my_index].clone(); + + println!("🦅 Starting Hawk node {my_index}"); + + let grpc_config = GrpcConfig { + timeout_duration: Duration::from_secs(1), + }; + + let player = GrpcNetworking::new(my_identity.clone(), grpc_config); + + // Start server. + { + let player = player.clone(); + let socket = my_address.parse().unwrap(); + tokio::spawn(async move { + Server::builder() + .add_service(PartyNodeServer::new(player)) + .serve(socket) + .await + .unwrap(); + }); + } + + // TODO: Retry until all servers are up. + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + // Connect to other players. + izip!(&identities, &addresses) + .filter(|(_, address)| address != &&my_address) + .map(|(identity, address)| { + let player = player.clone(); + let identity = identity.clone(); + let url = format!("http://{}", address); + async move { player.connect_to_party(identity, &url).await } + }) + .collect::>() + .join_all() + .await + .into_iter() + .collect::>()?; + + // ---- My state ---- + // TODO: Persistence. + + let iris_store = SharedIrisesRef::default(); + let mut graph_store = GraphMem::::new(); + let mut graph_rng = AesRng::seed_from_u64(123); + + // ---- MPC session ---- + // TODO: Manage parallel sessions. + + let session_id = SessionId::from(0_u64); + let my_session_seed = [0_u8; 16]; + + player.create_session(session_id).await?; + + let boot_session = BootSession { + session_id, + role_assignments: Arc::new(role_assignments.clone()), + networking: Arc::new(player.clone()), + own_identity: my_identity.clone(), + }; + + let prf = setup_replicated_prf(&boot_session, my_session_seed).await?; + + let session = Session { + boot_session, + setup: prf, + }; + + let mut aby3_store = Aby3Store { + session, + storage: iris_store, + owner: my_identity, + }; + assert_eq!(aby3_store.get_owner_index(), my_index); + + // ---- Requests ---- + // TODO: Listen for external requests. + + let n_inserts = 10; + let iris_rng = &mut AesRng::seed_from_u64(1337); + + let my_iris_shares = IrisDB::new_random_rng(n_inserts, iris_rng) + .db + .into_iter() + .map(|iris| generate_galois_iris_shares(iris_rng, iris)[my_index].clone()) + .collect_vec(); + + for iris in my_iris_shares { + let query = aby3_store.prepare_query(iris); + + search_params + .insert(&mut aby3_store, &mut graph_store, &query, &mut graph_rng) + .await; + } + + println!("🎉 Inserted {n_inserts} items into the database"); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_hawk_main() -> Result<()> { + ["0", "1", "2"] + .into_iter() + .map(|index| async move { + let args = HawkArgs::parse_from(&["hawk_main", "--party-index", index]); + + hawk_main(args).await + }) + .collect::>() + .join_all() + .await + .into_iter() + .collect::>() + } +} diff --git a/iris-mpc-cpu/src/execution/mod.rs b/iris-mpc-cpu/src/execution/mod.rs index f6c2aefbe..b79e343f1 100644 --- a/iris-mpc-cpu/src/execution/mod.rs +++ b/iris-mpc-cpu/src/execution/mod.rs @@ -1,3 +1,4 @@ pub mod local; pub mod player; pub mod session; +pub mod hawk_main;