From 255525c98f09ce91a93844d50ace7ee695f80aed Mon Sep 17 00:00:00 2001 From: bhargav Date: Tue, 28 Nov 2023 21:50:33 -0800 Subject: [PATCH 1/2] chore: par feature tags --- Cargo.toml | 3 +++ README.md | 4 +-- src/ntt.rs | 64 +++++++++++++++++++++++++++++++++++++++++++++++ src/polynomial.rs | 33 ++++++++++++++++++++++++ 4 files changed, 102 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1bf54c5..2423a14 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,3 +23,6 @@ criterion = { version = "0.5.1", features = ["html_reports"] } [[bench]] name = "benchmark" harness = false + +[features] +parallel = [] diff --git a/README.md b/README.md index c921a96..e86c4a1 100644 --- a/README.md +++ b/README.md @@ -31,9 +31,9 @@ Generate benchmarks using: ```bash # If you don't have it already -cargo install cargo-criterion criterion-table +cargo install cargo-criterion criterion-table --cfg -cargo criterion --message-format=json | criterion-table > BENCHMARKS.md +cargo criterion --message-format=json --features parallel | criterion-table > BENCHMARKS.md ``` Benchmarks are also available [here](./BENCHMARKS.md) diff --git a/src/ntt.rs b/src/ntt.rs index cc3e937..c93e5d4 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -73,6 +73,7 @@ fn order_reverse(inp: &mut Vec) { }); } +#[cfg(feature = "parallel")] fn fft(inp: Vec, c: &Constants, w: BigInt) -> Vec { assert!(inp.len().is_power_of_two()); let mut inp = inp.clone(); @@ -123,10 +124,62 @@ fn fft(inp: Vec, c: &Constants, w: BigInt) -> Vec { inp } +#[cfg(not(feature = "parallel"))] +fn fft(inp: Vec, c: &Constants, w: BigInt) -> Vec { + assert!(inp.len().is_power_of_two()); + let mut inp = inp.clone(); + let N = inp.len(); + let MOD = BigInt::from(c.N); + let ONE = BigInt::from(1); + let mut pre: Vec = vec![ONE; N / 2]; + let CHUNK_COUNT = 128; + let chunk_count = BigInt::from(CHUNK_COUNT); + + pre.chunks_mut(CHUNK_COUNT) + .enumerate() + .for_each(|(i, arr)| arr[0] = w.mod_exp(BigInt::from(i) * chunk_count, MOD)); + pre.chunks_mut(CHUNK_COUNT).for_each(|x| { + (1..x.len()).for_each(|y| { + let _x = x.to_vec(); + x[y] = (w * x[y - 1]).rem(MOD); + }) + }); + order_reverse(&mut inp); + + let mut gap = 1; + + while gap < inp.len() { + let nchunks = inp.len() / (2 * gap); + inp.chunks_mut(2 * gap).for_each(|cxi| { + let (lo, hi) = cxi.split_at_mut(gap); + lo.iter_mut() + .zip(hi) + .enumerate() + .for_each(|(idx, (lo, hi))| { + *hi = (*hi * pre[nchunks * idx]).rem(MOD); + let neg = if *lo < *hi { + (MOD + *lo) - *hi + } else { + *lo - *hi + }; + *lo = if *lo + *hi >= MOD { + (*lo + *hi) - MOD + } else { + *lo + *hi + }; + *hi = neg; + }); + }); + gap *= 2; + } + inp +} + pub fn forward(inp: Vec, c: &Constants) -> Vec { fft(inp, c, c.w) } +#[cfg(feature = "parallel")] pub fn inverse(inp: Vec, c: &Constants) -> Vec { let mut inv = BigInt::from(inp.len()); let _ = inv.set_mod(c.N); @@ -137,6 +190,17 @@ pub fn inverse(inp: Vec, c: &Constants) -> Vec { res } +#[cfg(not(feature = "parallel"))] +pub fn inverse(inp: Vec, c: &Constants) -> Vec { + let mut inv = BigInt::from(inp.len()); + let _ = inv.set_mod(c.N); + let inv = inv.invert(); + let w = c.w.invert(); + let mut res = fft(inp, c, w); + res.iter_mut().for_each(|x| *x = (inv * (*x)).rem(c.N)); + res +} + #[cfg(test)] mod tests { use rand::Rng; diff --git a/src/polynomial.rs b/src/polynomial.rs index 628b7a5..6871221 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -50,6 +50,7 @@ impl Polynomial { Polynomial { coef: out } } + #[cfg(feature = "parallel")] pub fn mul(self, rhs: Polynomial, c: &Constants) -> Polynomial { let v1_deg = self.degree(); let v2_deg = rhs.degree(); @@ -81,6 +82,38 @@ impl Polynomial { } } + #[cfg(not(feature = "parallel"))] + pub fn mul(self, rhs: Polynomial, c: &Constants) -> Polynomial { + let v1_deg = self.degree(); + let v2_deg = rhs.degree(); + let n = (self.len() + rhs.len()).next_power_of_two(); + let ZERO = BigInt::from(0); + + let v1 = vec![ZERO; n - self.len()] + .into_iter() + .chain(self.coef.into_iter()) + .collect(); + let v2 = vec![ZERO; n - rhs.len()] + .into_iter() + .chain(rhs.coef.into_iter()) + .collect(); + + let a_forward = forward(v1, &c); + let b_forward = forward(v2, &c); + + let mut mul = vec![ZERO; n as usize]; + mul.iter_mut() + .enumerate() + .for_each(|(i, x)| *x = (a_forward[i] * b_forward[i]).rem(c.N)); + + let coef = inverse(mul, &c); + // n - polynomial degree - 1 + let start = n - (v1_deg + v2_deg + 1) - 1; + Polynomial { + coef: coef[start..=(start + v1_deg + v2_deg)].to_vec(), + } + } + pub fn diff(mut self) -> Self { let N = self.len(); for n in (1..N).rev() { From 5283c9e61e1ade571c1e5e1e54cf40acf60ab737 Mon Sep 17 00:00:00 2001 From: bhargav Date: Wed, 29 Nov 2023 01:12:15 -0800 Subject: [PATCH 2/2] feat: pow --- src/numbers.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/numbers.rs b/src/numbers.rs index 750dad1..2a7ad4d 100644 --- a/src/numbers.rs +++ b/src/numbers.rs @@ -75,6 +75,12 @@ impl BigInt { *self.v.params() } + pub fn pow(&self, n: u128) -> BigInt { + BigInt { + v: self.v.pow(&Uint::<4>::from_u128(n)), + } + } + pub fn mod_exp(&self, exp: BigInt, M: BigInt) -> BigInt { let mut res: BigInt = if !exp.is_even() { self.clone()