Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ns/zk curve conformance #1743

Merged
merged 2 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/aws_tfhe_fast_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ jobs:
run: |
make test_safe_deserialization

- name: Run zk tests
run: |
make test_zk

- name: Slack Notification
if: ${{ failure() }}
continue-on-error: true
Expand Down
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,11 @@ test_safe_deserialization: install_rs_build_toolchain install_cargo_nextest
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache -p $(TFHE_SPEC) -- safe_deserialization::

.PHONY: test_zk # Run the tests for the zk module of the TFHE-rs crate
test_zk: install_rs_build_toolchain install_cargo_nextest
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),shortint,zk-pok -p $(TFHE_SPEC) -- zk::

.PHONY: test_integer # Run all the tests for integer
test_integer: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
Expand Down
20 changes: 20 additions & 0 deletions tfhe-zk-pok/src/curve_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ pub trait CurveGroupOps<Zp>:
fn to_le_bytes(self) -> impl AsRef<[u8]>;
fn double(self) -> Self;
fn normalize(self) -> Self::Affine;
fn validate_projective(&self) -> bool {
Self::validate_affine(&self.normalize())
}
fn validate_affine(affine: &Self::Affine) -> bool;
mayeul-zama marked this conversation as resolved.
Show resolved Hide resolved
}

/// Mark that an element can be compressed, by storing only the 'x' coordinates of the affine
Expand Down Expand Up @@ -231,6 +235,10 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G1 {
inner: self.inner.into_affine(),
}
}

fn validate_affine(affine: &Self::Affine) -> bool {
affine.validate()
}
}

impl CurveGroupOps<bls12_381::Zp> for bls12_381::G2 {
Expand Down Expand Up @@ -271,6 +279,10 @@ impl CurveGroupOps<bls12_381::Zp> for bls12_381::G2 {
inner: self.inner.into_affine(),
}
}

fn validate_affine(affine: &Self::Affine) -> bool {
affine.validate()
}
}

impl PairingGroupOps<bls12_381::Zp, bls12_381::G1, bls12_381::G2> for bls12_381::Gt {
Expand Down Expand Up @@ -368,6 +380,10 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G1 {
inner: self.inner.into_affine(),
}
}

fn validate_affine(affine: &Self::Affine) -> bool {
affine.validate()
}
}

impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
Expand Down Expand Up @@ -408,6 +424,10 @@ impl CurveGroupOps<bls12_446::Zp> for bls12_446::G2 {
inner: self.inner.into_affine(),
}
}

fn validate_affine(affine: &Self::Affine) -> bool {
affine.validate()
}
}

impl PairingGroupOps<bls12_446::Zp, bls12_446::G1, bls12_446::G2> for bls12_446::Gt {
Expand Down
8 changes: 8 additions & 0 deletions tfhe-zk-pok/src/curve_api/bls12_381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ mod g1 {
.unwrap(),
}
}

pub fn validate(&self) -> bool {
self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve()
}
}

#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
Expand Down Expand Up @@ -310,6 +314,10 @@ mod g2 {
.unwrap(),
}
}

pub fn validate(&self) -> bool {
self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve()
}
}

#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
Expand Down
8 changes: 8 additions & 0 deletions tfhe-zk-pok/src/curve_api/bls12_446.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ mod g1 {
.unwrap(),
}
}

pub fn validate(&self) -> bool {
self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve()
}
}

#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash, Versionize)]
Expand Down Expand Up @@ -316,6 +320,10 @@ mod g2 {
}
}

pub fn validate(&self) -> bool {
self.inner.is_on_curve() && self.inner.is_in_correct_subgroup_assuming_on_curve()
}

// m is an intermediate variable that's used in both the curve point addition and pairing
// functions. we cache it since it requires a Zp division
// https://hackmd.io/@tazAymRSQCGXTUKkbh1BAg/Sk27liTW9#Math-Formula-for-Point-Addition
Expand Down
57 changes: 57 additions & 0 deletions tfhe-zk-pok/src/proofs/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use crate::backward_compatibility::GroupElementsVersions;

use crate::curve_api::{Compressible, Curve, CurveGroupOps, FieldOps, PairingGroupOps};
use crate::serialization::{
InvalidSerializedGroupElementsError, SerializableG1Affine, SerializableG2Affine,
SerializableGroupElements,
};
use core::ops::{Index, IndexMut};
use rand::{Rng, RngCore};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use tfhe_versionable::Versionize;

#[derive(Clone, Copy, Debug, serde::Serialize, serde::Deserialize, Versionize)]
Expand Down Expand Up @@ -108,6 +110,16 @@ impl<G: Curve> GroupElements<G> {
message_len,
}
}

/// Check if the elements are valid for their respective groups
pub fn is_valid(&self) -> bool {
let (g_list_valid, g_hat_list_valid) = rayon::join(
|| self.g_list.0.par_iter().all(G::G1::validate_affine),
|| self.g_hat_list.0.par_iter().all(G::G2::validate_affine),
);

g_list_valid && g_hat_list_valid
}
}

impl<G: Curve> Compressible for GroupElements<G>
Expand Down Expand Up @@ -152,6 +164,8 @@ mod test {
#![allow(non_snake_case)]
use std::fmt::Display;

use ark_ec::{short_weierstrass, CurveConfig};
use ark_ff::UniformRand;
use bincode::ErrorKind;
use rand::rngs::StdRng;
use rand::Rng;
Expand Down Expand Up @@ -359,4 +373,47 @@ mod test {
PkeTestCiphertext { c1, c2 }
}
}

/// Return a point with coordinates (x, y) that is randomly chosen and not on the curve
pub(super) fn point_not_on_curve<Config: short_weierstrass::SWCurveConfig>(
mayeul-zama marked this conversation as resolved.
Show resolved Hide resolved
rng: &mut StdRng,
) -> short_weierstrass::Affine<Config> {
loop {
let fake_x = <Config as CurveConfig>::BaseField::rand(rng);
let fake_y = <Config as CurveConfig>::BaseField::rand(rng);

let point = short_weierstrass::Affine::new_unchecked(fake_x, fake_y);

if !point.is_on_curve() {
return point;
}
}
}

/// Return a random point on the curve
pub(super) fn point_on_curve<Config: short_weierstrass::SWCurveConfig>(
rng: &mut StdRng,
) -> short_weierstrass::Affine<Config> {
loop {
let x = <Config as CurveConfig>::BaseField::rand(rng);
let is_positive = bool::rand(rng);
if let Some(point) =
short_weierstrass::Affine::get_point_from_x_unchecked(x, is_positive)
{
return point;
}
}
}

/// Return a random point that is on the curve but not in the correct subgroup
pub(super) fn point_on_curve_wrong_subgroup<Config: short_weierstrass::SWCurveConfig>(
rng: &mut StdRng,
) -> short_weierstrass::Affine<Config> {
loop {
let point = point_on_curve(rng);
if !Config::is_in_correct_subgroup_assuming_on_curve(&point) {
return point;
}
}
}
}
Loading
Loading