Skip to content

Commit

Permalink
poseidoninternalrelation
Browse files Browse the repository at this point in the history
  • Loading branch information
rw0x0 committed Sep 30, 2024
1 parent 7de8544 commit b67c914
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 24 deletions.
1 change: 1 addition & 0 deletions co-noir/co-ultrahonk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ ark-bn254.workspace = true
eyre.workspace = true
itertools.workspace = true
mpc-core.workspace = true
num-bigint.workspace = true
tracing.workspace = true
ultrahonk.workspace = true
17 changes: 9 additions & 8 deletions co-noir/co-ultrahonk/src/co_decider/co_sumcheck/round.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::{
co_decider::{
relations::{
poseidon2_external_relation::Poseidon2ExternalRelation, AllRelationAcc, Relation,
poseidon2_external_relation::Poseidon2ExternalRelation,
poseidon2_internal_relation::Poseidon2InternalRelation, AllRelationAcc, Relation,
},
types::{ProverUnivariates, RelationParameters, MAX_PARTIAL_RELATION_LENGTH},
univariates::SharedUnivariate,
Expand Down Expand Up @@ -213,13 +214,13 @@ impl SumcheckRound {
relation_parameters,
scaling_factor,
)?;
// Self::accumulate_one_relation_univariates::<_, _, Poseidon2InternalRelation>(
// driver,
// &mut univariate_accumulators.r_pos_int,
// extended_edges,
// relation_parameters,
// scaling_factor,
// )?;
Self::accumulate_one_relation_univariates::<_, _, Poseidon2InternalRelation>(
driver,
&mut univariate_accumulators.r_pos_int,
extended_edges,
relation_parameters,
scaling_factor,
)?;
Ok(())
}

Expand Down
30 changes: 23 additions & 7 deletions co-noir/co-ultrahonk/src/co_decider/relations/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub(crate) mod poseidon2_external_relation;
pub(crate) mod poseidon2_internal_relation;

use super::{
co_sumcheck::round::SumcheckRoundOutput,
Expand All @@ -8,6 +9,7 @@ use super::{
use ark_ec::pairing::Pairing;
use mpc_core::traits::PrimeFieldMpcProtocol;
use poseidon2_external_relation::{Poseidon2ExternalRelation, Poseidon2ExternalRelationAcc};
use poseidon2_internal_relation::Poseidon2InternalRelationAcc;
use ultrahonk::prelude::{HonkCurve, HonkProofResult, TranscriptFieldType, Univariate};

pub(crate) trait Relation<T, P: HonkCurve<TranscriptFieldType>>
Expand Down Expand Up @@ -49,7 +51,14 @@ pub(crate) struct AllRelationAcc<T, P: Pairing>
where
T: PrimeFieldMpcProtocol<P::ScalarField>,
{
// pub(crate) r_arith: UltraArithmeticRelationAcc<T, P>,
// pub(crate) r_perm: UltraPermutationRelationAcc<T, P>,
// pub(crate) r_delta: DeltaRangeConstraintRelationAcc<T, P>,
// pub(crate) r_elliptic: EllipticRelationAcc<T, P>,
// pub(crate) r_aux: AuxiliaryRelationAcc<T, P>,
// pub(crate) r_lookup: LogDerivLookupRelationAcc<T, P>,
pub(crate) r_pos_ext: Poseidon2ExternalRelationAcc<T, P>,
pub(crate) r_pos_int: Poseidon2InternalRelationAcc<T, P>,
}

impl<T, P: Pairing> Default for AllRelationAcc<T, P>
Expand All @@ -58,7 +67,14 @@ where
{
fn default() -> Self {
Self {
// r_arith: Default::default(),
// r_perm: Default::default(),
// r_delta: Default::default(),
// r_elliptic: Default::default(),
// r_aux: Default::default(),
// r_lookup: Default::default(),
r_pos_ext: Default::default(),
r_pos_int: Default::default(),
}
}
}
Expand All @@ -81,7 +97,7 @@ where
// self.r_aux.scale(driver, &elements[9..15]);
// self.r_lookup.scale(driver, &elements[15..17]);
self.r_pos_ext.scale(driver, &elements[17..21]);
// self.r_pos_int.scale(driver, &elements[21..]);
self.r_pos_int.scale(driver, &elements[21..]);
}

pub(crate) fn extend_and_batch_univariates<const SIZE: usize>(
Expand Down Expand Up @@ -133,11 +149,11 @@ where
extended_random_poly,
partial_evaluation_result,
);
// self.r_pos_int.extend_and_batch_univariates(
// driver,
// result,
// extended_random_poly,
// partial_evaluation_result,
// );
self.r_pos_int.extend_and_batch_univariates(
driver,
result,
extended_random_poly,
partial_evaluation_result,
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ where
let s3 = w_o.add_public(driver, q_o);
let s4 = w_4.add_public(driver, q_4);

// todo!("Poseidon2ExternalRelation::accumulate");

// apply s-box round
let s = SharedUnivariate::univariates_to_vec(&[s1, s2, s3, s4]);
let u = driver.mul_many(&s, &s)?;
Expand All @@ -167,7 +165,6 @@ where
let u = SharedUnivariate::vec_to_univariates(&u);

// matrix mul v = M_E * u with 14 additions

let t0 = u[0].add(driver, &u[1]); // u_1 + u_2
let t1 = u[2].add(driver, &u[3]); // u_3 + u_4
let t2 = u[1].add(driver, &u[1]); // 2u_2
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
use super::Relation;
use crate::co_decider::{
types::{ProverUnivariates, RelationParameters},
univariates::SharedUnivariate,
};
use ark_ec::pairing::Pairing;
use ark_ff::{PrimeField, Zero};
use mpc_core::traits::PrimeFieldMpcProtocol;
use num_bigint::BigUint;
use ultrahonk::prelude::{
HonkCurve, HonkProofResult, TranscriptFieldType, Univariate, POSEIDON2_BN254_T4_PARAMS,
};

#[derive(Clone, Debug)]
pub(crate) struct Poseidon2InternalRelationAcc<T, P: Pairing>
where
T: PrimeFieldMpcProtocol<P::ScalarField>,
{
pub(crate) r0: SharedUnivariate<T, P, 7>,
pub(crate) r1: SharedUnivariate<T, P, 7>,
pub(crate) r2: SharedUnivariate<T, P, 7>,
pub(crate) r3: SharedUnivariate<T, P, 7>,
}

impl<T, P: Pairing> Default for Poseidon2InternalRelationAcc<T, P>
where
T: PrimeFieldMpcProtocol<P::ScalarField>,
{
fn default() -> Self {
Self {
r0: Default::default(),
r1: Default::default(),
r2: Default::default(),
r3: Default::default(),
}
}
}
impl<T, P: Pairing> Poseidon2InternalRelationAcc<T, P>
where
T: PrimeFieldMpcProtocol<P::ScalarField>,
{
pub(crate) fn scale(&mut self, driver: &mut T, elements: &[P::ScalarField]) {
assert!(elements.len() == Poseidon2InternalRelation::NUM_RELATIONS);
self.r0.scale(driver, &elements[0]);
self.r1.scale(driver, &elements[1]);
self.r2.scale(driver, &elements[2]);
self.r3.scale(driver, &elements[3]);
}

pub(crate) fn extend_and_batch_univariates<const SIZE: usize>(
&self,
driver: &mut T,
result: &mut SharedUnivariate<T, P, SIZE>,
extended_random_poly: &Univariate<P::ScalarField, SIZE>,
partial_evaluation_result: &P::ScalarField,
) {
self.r0.extend_and_batch_univariates(
driver,
result,
extended_random_poly,
partial_evaluation_result,
true,
);

self.r1.extend_and_batch_univariates(
driver,
result,
extended_random_poly,
partial_evaluation_result,
true,
);

self.r2.extend_and_batch_univariates(
driver,
result,
extended_random_poly,
partial_evaluation_result,
true,
);

self.r3.extend_and_batch_univariates(
driver,
result,
extended_random_poly,
partial_evaluation_result,
true,
);
}
}

pub(crate) struct Poseidon2InternalRelation {}

impl Poseidon2InternalRelation {
pub(crate) const NUM_RELATIONS: usize = 4;
}

impl<T, P: HonkCurve<TranscriptFieldType>> Relation<T, P> for Poseidon2InternalRelation
where
T: PrimeFieldMpcProtocol<P::ScalarField>,
{
type Acc = Poseidon2InternalRelationAcc<T, P>;
const SKIPPABLE: bool = true;

fn skip(input: &ProverUnivariates<T, P>) -> bool {
<Self as Relation<T, P>>::check_skippable();
input.precomputed.q_poseidon2_internal().is_zero()
}

/**
* @brief Expression for the poseidon2 internal round relation, based on I_i in Section 6 of
* https://eprint.iacr.org/2023/323.pdf.
* @details This relation is defined as C(in(X)...) :=
* q_poseidon2_internal * ( (v1 - w_1_shift) + \alpha * (v2 - w_2_shift) +
* \alpha^2 * (v3 - w_3_shift) + \alpha^3 * (v4 - w_4_shift) ) = 0 where:
* u1 := (w_1 + q_1)^5
* sum := u1 + w_2 + w_3 + w_4
* v1 := u1 * D1 + sum
* v2 := w_2 * D2 + sum
* v3 := w_3 * D3 + sum
* v4 := w_4 * D4 + sum
* Di is the ith internal diagonal value - 1 of the internal matrix M_I
*
* @param evals transformed to `evals + C(in(X)...)*scaling_factor`
* @param in an std::array containing the fully extended Univariate edges.
* @param parameters contains beta, gamma, and public_input_delta, ....
* @param scaling_factor optional term to scale the evaluation before adding to evals.
*/
fn accumulate(
driver: &mut T,
univariate_accumulator: &mut Self::Acc,
input: &ProverUnivariates<T, P>,
_relation_parameters: &RelationParameters<P::ScalarField>,
scaling_factor: &P::ScalarField,
) -> HonkProofResult<()> {
tracing::trace!("Accumulate Poseidon2InternalRelation");

let w_l = input.witness.w_l();
let w_r = input.witness.w_r();
let w_o = input.witness.w_o();
let w_4 = input.witness.w_4();
let w_l_shift = input.shifted_witness.w_l();
let w_r_shift = input.shifted_witness.w_r();
let w_o_shift = input.shifted_witness.w_o();
let w_4_shift = input.shifted_witness.w_4();
let q_l = input.precomputed.q_l();
let q_poseidon2_internal = input.precomputed.q_poseidon2_internal();

// add round constants
let s1 = w_l.add_public(driver, q_l);

// apply s-box round
let u1 = driver.mul_many(s1.as_ref(), s1.as_ref())?;
let u1 = driver.mul_many(u1.as_ref(), u1.as_ref())?;
let u1 = driver.mul_many(u1.as_ref(), s1.as_ref())?;
let mut u2 = w_r.to_owned();
let mut u3 = w_o.to_owned();
let mut u4 = w_4.to_owned();
let mut u1 = SharedUnivariate::from_vec(&u1);

// matrix mul with v = M_I * u 4 muls and 7 additions
let sum = u1.add(driver, &u2);
let sum = sum.add(driver, &u3);
let sum = sum.add(driver, &u4);

let q_pos_by_scaling = q_poseidon2_internal.to_owned() * scaling_factor;

// TODO this poseidon instance is very hardcoded to the bn254 curve
let internal_matrix_diag_0 = P::ScalarField::from(BigUint::from(
POSEIDON2_BN254_T4_PARAMS.mat_internal_diag_m_1[0],
));
let internal_matrix_diag_1 = P::ScalarField::from(BigUint::from(
POSEIDON2_BN254_T4_PARAMS.mat_internal_diag_m_1[1],
));
let internal_matrix_diag_2 = P::ScalarField::from(BigUint::from(
POSEIDON2_BN254_T4_PARAMS.mat_internal_diag_m_1[2],
));
let internal_matrix_diag_3 = P::ScalarField::from(BigUint::from(
POSEIDON2_BN254_T4_PARAMS.mat_internal_diag_m_1[3],
));

u1.scale(driver, &internal_matrix_diag_0);
let v1 = u1.add(driver, &sum);
let tmp = v1.sub(driver, w_l_shift);
let tmp = tmp.mul_public(driver, &q_pos_by_scaling);

for i in 0..univariate_accumulator.r0.evaluations.len() {
univariate_accumulator.r0.evaluations[i] = driver.add(
&univariate_accumulator.r0.evaluations[i],
&tmp.evaluations[i],
);
}

///////////////////////////////////////////////////////////////////////

u2.scale(driver, &internal_matrix_diag_1);
let v2 = u2.add(driver, &sum);
let tmp = v2.sub(driver, w_r_shift);
let tmp = tmp.mul_public(driver, &q_pos_by_scaling);

for i in 0..univariate_accumulator.r1.evaluations.len() {
univariate_accumulator.r1.evaluations[i] = driver.add(
&univariate_accumulator.r1.evaluations[i],
&tmp.evaluations[i],
);
}

///////////////////////////////////////////////////////////////////////

u3.scale(driver, &internal_matrix_diag_2);
let v3 = u3.add(driver, &sum);
let tmp = v3.sub(driver, w_o_shift);
let tmp = tmp.mul_public(driver, &q_pos_by_scaling);

for i in 0..univariate_accumulator.r2.evaluations.len() {
univariate_accumulator.r2.evaluations[i] = driver.add(
&univariate_accumulator.r2.evaluations[i],
&tmp.evaluations[i],
);
}

///////////////////////////////////////////////////////////////////////
u4.scale(driver, &internal_matrix_diag_3);
let v4 = u4.add(driver, &sum);
let tmp = v4.sub(driver, w_4_shift);
let tmp = tmp.mul_public(driver, &q_pos_by_scaling);

for i in 0..univariate_accumulator.r3.evaluations.len() {
univariate_accumulator.r3.evaluations[i] = driver.add(
&univariate_accumulator.r3.evaluations[i],
&tmp.evaluations[i],
);
}

Ok(())
}
}
16 changes: 16 additions & 0 deletions co-noir/co-ultrahonk/src/co_decider/univariates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ impl<T, P: Pairing, const SIZE: usize> SharedUnivariate<T, P, SIZE>
where
T: PrimeFieldMpcProtocol<P::ScalarField>,
{
pub(crate) fn from_vec(vec: &[T::FieldShare]) -> Self {
assert_eq!(vec.len(), SIZE);
let mut res = Self::default();
res.evaluations.clone_from_slice(vec);
res
}

pub(crate) fn scale(&mut self, driver: &mut T, rhs: &P::ScalarField) {
for i in 0..SIZE {
self.evaluations[i] = driver.mul_with_public(rhs, &self.evaluations[i]);
Expand Down Expand Up @@ -326,3 +333,12 @@ where
f.debug_list().entries(self.evaluations.iter()).finish()
}
}

impl<T, P: Pairing, const SIZE: usize> AsRef<[T::FieldShare]> for SharedUnivariate<T, P, SIZE>
where
T: PrimeFieldMpcProtocol<P::ScalarField>,
{
fn as_ref(&self) -> &[T::FieldShare] {
self.evaluations.as_ref()
}
}
Loading

0 comments on commit b67c914

Please sign in to comment.