Skip to content

Commit

Permalink
SIMD backend for poseidon252 merkle ops - CPU IMPL (#809)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/809)
<!-- Reviewable:end -->
  • Loading branch information
shaharsamocha7 committed Aug 29, 2024
2 parents e3858fb + b15bcbd commit 92e2001
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 27 deletions.
2 changes: 2 additions & 0 deletions crates/prover/src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ mod circle;
mod fri;
mod grind;
pub mod lookups;
#[cfg(not(target_arch = "wasm32"))]
mod poseidon252;
pub mod quotients;

use std::fmt::Debug;
Expand Down
24 changes: 24 additions & 0 deletions crates/prover/src/core/backend/cpu/poseidon252.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use itertools::Itertools;
use starknet_ff::FieldElement as FieldElement252;

use super::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher;

impl MerkleOps<Poseidon252MerkleHasher> for CpuBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<FieldElement252>>,
columns: &[&Vec<BaseField>],
) -> Vec<FieldElement252> {
(0..(1 << log_size))
.map(|i| {
Poseidon252MerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column[i]).collect_vec(),
)
})
.collect()
}
}
6 changes: 6 additions & 0 deletions crates/prover/src/core/backend/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use serde::{Deserialize, Serialize};

use super::{Backend, BackendForChannel};
use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel;

pub mod accumulation;
pub mod bit_reverse;
Expand All @@ -15,6 +17,8 @@ pub mod fri;
mod grind;
pub mod lookups;
pub mod m31;
#[cfg(not(target_arch = "wasm32"))]
pub mod poseidon252;
pub mod prefix_sum;
pub mod qm31;
pub mod quotients;
Expand All @@ -26,3 +30,5 @@ pub struct SimdBackend;

impl Backend for SimdBackend {}
impl BackendForChannel<Blake2sMerkleChannel> for SimdBackend {}
#[cfg(not(target_arch = "wasm32"))]
impl BackendForChannel<Poseidon252MerkleChannel> for SimdBackend {}
36 changes: 36 additions & 0 deletions crates/prover/src/core/backend/simd/poseidon252.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use itertools::Itertools;
use starknet_ff::FieldElement as FieldElement252;

use super::SimdBackend;
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::fields::m31::BaseField;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::ops::MerkleHasher;
use crate::core::vcs::ops::MerkleOps;
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher;

impl ColumnOps<FieldElement252> for SimdBackend {
type Column = Vec<FieldElement252>;

fn bit_reverse_column(_column: &mut Self::Column) {
unimplemented!()
}
}

impl MerkleOps<Poseidon252MerkleHasher> for SimdBackend {
// TODO(ShaharS): replace with SIMD implementation.
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<FieldElement252>>,
columns: &[&Col<Self, BaseField>],
) -> Vec<FieldElement252> {
(0..(1 << log_size))
.map(|i| {
Poseidon252MerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column.at(i)).collect_vec(),
)
})
.collect()
}
}
21 changes: 1 addition & 20 deletions crates/prover/src/core/vcs/poseidon252_merkle.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use itertools::Itertools;
use num_traits::Zero;
use serde::{Deserialize, Serialize};
use starknet_crypto::{poseidon_hash, poseidon_hash_many};
use starknet_ff::FieldElement as FieldElement252;

use super::ops::{MerkleHasher, MerkleOps};
use crate::core::backend::CpuBackend;
use super::ops::MerkleHasher;
use crate::core::channel::{MerkleChannel, Poseidon252Channel};
use crate::core::fields::m31::BaseField;
use crate::core::vcs::hash::Hash;
Expand Down Expand Up @@ -46,23 +44,6 @@ impl MerkleHasher for Poseidon252MerkleHasher {
}
}

impl MerkleOps<Poseidon252MerkleHasher> for CpuBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<FieldElement252>>,
columns: &[&Vec<BaseField>],
) -> Vec<FieldElement252> {
(0..(1 << log_size))
.map(|i| {
Poseidon252MerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column[i]).collect_vec(),
)
})
.collect()
}
}

impl Hash for FieldElement252 {}

#[derive(Default)]
Expand Down
5 changes: 3 additions & 2 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::core::pcs::{CommitmentSchemeProver, PcsConfig, TreeSubspan};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{prove, StarkProof};
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher};
use crate::core::{ColumnVec, InteractionElements};

pub type PlonkComponent = FrameworkComponent<PlonkEval>;
Expand Down Expand Up @@ -182,7 +182,8 @@ pub fn prove_fibonacci_plonk(

// Setup protocol.
let channel = &mut Blake2sChannel::default();
let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles);
let commitment_scheme =
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);

// Trace.
let span = span!(Level::INFO, "Trace").entered();
Expand Down
5 changes: 3 additions & 2 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::core::pcs::{CommitmentSchemeProver, PcsConfig};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{prove, StarkProof};
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher};
use crate::core::{ColumnVec, InteractionElements};

const N_LOG_INSTANCES_PER_ROW: usize = 3;
Expand Down Expand Up @@ -336,7 +336,8 @@ pub fn prove_poseidon(

// Setup protocol.
let channel = &mut Blake2sChannel::default();
let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles);
let commitment_scheme =
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);

// Trace.
let span = span!(Level::INFO, "Trace").entered();
Expand Down
2 changes: 0 additions & 2 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,6 @@ mod tests {
#[cfg(not(target_arch = "wasm32"))]
#[test_log::test]
fn test_single_instance_wide_fib_prove_with_poseidon() {
use crate::core::backend::CpuBackend;

const LOG_N_INSTANCES: u32 = 0;
let config = PcsConfig::default();
let component = WideFibComponent {
Expand Down
3 changes: 2 additions & 1 deletion crates/prover/src/examples/wide_fibonacci/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ mod tests {
span.exit();
let channel = &mut Blake2sChannel::default();
let air = SimdWideFibAir { component };
let proof = commit_and_prove(&air, channel, trace, config).unwrap();
let proof =
commit_and_prove::<_, Blake2sMerkleChannel>(&air, channel, trace, config).unwrap();

let channel = &mut Blake2sChannel::default();
commit_and_verify::<Blake2sMerkleChannel>(proof, &air, channel, config).unwrap();
Expand Down

0 comments on commit 92e2001

Please sign in to comment.