diff --git a/engine-modexp/src/arith.rs b/engine-modexp/src/arith.rs index ca0c7de69..04a0b0218 100644 --- a/engine-modexp/src/arith.rs +++ b/engine-modexp/src/arith.rs @@ -158,32 +158,6 @@ pub fn mod_inv(x: Word) -> Word { y } -// Given x odd, computes `x^(-1) mod 2^(WORD_BYTES*out.digits.len())`. -// See `MODULAR-INVERSE` in https://link.springer.com/content/pdf/10.1007/3-540-46877-3_21.pdf -pub fn big_mod_inv(x: &MPNat, out: &mut MPNat, scratch: &mut [Word]) { - let s = out.digits.len(); - out.digits[0] = mod_inv(x.digits[0]); - - for digit_index in 1..s { - for i in 1..WORD_BITS { - let mask = (1 << i) - 1; - big_wrapping_mul(x, out, scratch); - scratch[digit_index] &= mask; - let q = 1 << (i - 1); - if scratch[digit_index] >= q { - out.digits[digit_index] += q; - } - scratch.fill(0); - } - big_wrapping_mul(x, out, scratch); - let q = 1 << (WORD_BITS - 1); - if scratch[digit_index] >= q { - out.digits[digit_index] += q; - } - scratch.fill(0); - } -} - /// Computes R mod n, where R = 2^(WORD_BITS*k) and k = n.digits.len() /// Note that if R = qn + r, q must be smaller than 2^WORD_BITS since `2^(WORD_BITS) * n > R` /// (adding a whole additional word to n is too much). 
@@ -294,14 +268,22 @@ pub fn big_sq(x: &MPNat, out: &mut [Word]) { out[i + i] = product; let mut c = carry as DoubleWord; for j in (i + 1)..s { - let product = (x.digits[i] as DoubleWord) * (x.digits[j] as DoubleWord); - let (product, overflow) = product.overflowing_add(product); - let sum = (out[i + j] as DoubleWord) + product + c; - out[i + j] = sum as Word; - c = (sum >> WORD_BITS) as DoubleWord; + let mut new_c: DoubleWord = 0; + let res = (x.digits[i] as DoubleWord) * (x.digits[j] as DoubleWord); + let (res, overflow) = res.overflowing_add(res); if overflow { - c += BASE; + new_c += BASE; } + let (res, overflow) = (out[i + j] as DoubleWord).overflowing_add(res); + if overflow { + new_c += BASE; + } + let (res, overflow) = res.overflowing_add(c); + if overflow { + new_c += BASE; + } + out[i + j] = res as Word; + c = new_c + ((res >> WORD_BITS) as DoubleWord); } let (sum, carry) = carrying_add(out[i + s], c as Word, false); out[i + s] = sum; @@ -351,6 +333,11 @@ pub fn in_place_add(a: &mut [Word], b: &[Word]) -> bool { pub fn in_place_mul_sub(a: &mut [Word], x: &[Word], y: Word) -> Word { debug_assert!(a.len() == x.len()); + // a -= x*0 leaves a unchanged, so return early + if y == 0 { + return 0; + } + // carry is between -big_digit::MAX and 0, so to avoid overflow we store // offset_carry = carry + big_digit::MAX let mut offset_carry = Word::MAX; @@ -504,31 +491,6 @@ fn test_r_mod_n() { } } -#[test] -fn test_big_mod_inv() { - check_big_mod_inv(0x02_FF_FF_FF); - check_big_mod_inv(0x1234_0000_DDDD_FFFF); - check_big_mod_inv(0x52DA_9A91_F82D_6E17_FDF8_6743_2B58_7917); - - fn check_big_mod_inv(n: u128) { - let x = MPNat::from_big_endian(&n.to_be_bytes()); - let s = x.digits.len(); - let mut result = MPNat { digits: vec![0; s] }; - let mut scratch = vec![0; s]; - big_mod_inv(&x, &mut result, &mut scratch); - let n_inv = mp_nat_to_u128(&result); - if WORD_BITS * s < u128::BITS as usize { - assert_eq!( - n.wrapping_mul(n_inv) % (1 << (WORD_BITS * s)), - 1, - "{n} 
failed big_mod_inv check" - ); - } else { - assert_eq!(n.wrapping_mul(n_inv), 1, "{n} failed big_mod_inv check"); - } - } -} - #[test] fn test_in_place_shl() { check_in_place_shl(0, 0); @@ -655,6 +617,24 @@ fn test_big_sq() { }; assert_eq!(result, expected, "{a}^2 != {expected}"); } + + /* Test for addition overflows in the big_sq inner loop */ + { + let x = MPNat::from_big_endian(&[ + 0xff, 0xff, 0xff, 0xff, 0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x40, 0x00, + 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x80, 0x00, 0x00, 0x00, + ]); + let mut out = vec![0; 2 * x.digits.len() + 1]; + big_sq(&x, &mut out); + let result = MPNat { digits: out }.to_big_endian(); + let expected = vec![ + 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0xff, 0xff, 0xff, 0xfe, 0x40, 0x00, 0x00, 0x01, 0x90, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xbf, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + assert_eq!(result, expected); + } } #[test] diff --git a/engine-modexp/src/mpnat.rs b/engine-modexp/src/mpnat.rs index 312b106aa..76c523017 100644 --- a/engine-modexp/src/mpnat.rs +++ b/engine-modexp/src/mpnat.rs @@ -1,8 +1,8 @@ use crate::{ arith::{ - big_mod_inv, big_wrapping_mul, big_wrapping_pow, borrowing_sub, carrying_add, - compute_r_mod_n, in_place_add, in_place_mul_sub, in_place_shl, in_place_shr, - join_as_double, mod_inv, monpro, monsq, + big_wrapping_mul, big_wrapping_pow, borrowing_sub, carrying_add, compute_r_mod_n, + in_place_add, in_place_mul_sub, in_place_shl, in_place_shr, join_as_double, mod_inv, + monpro, monsq, }, maybe_std::{vec, Vec}, }; @@ -22,6 +22,61 @@ pub struct MPNat { } impl MPNat { + // Koç's algorithm for inversion mod 2^k + // https://eprint.iacr.org/2017/411.pdf + fn koc_2017_inverse(aa: &Self, k: usize) -> Self { + debug_assert!(aa.is_odd()); + + let length = k / WORD_BITS; + let mut b = MPNat { + digits: vec![0; length + 1], + }; + b.digits[0] = 1; + + 
let mut a = MPNat { + digits: aa.digits.clone(), + }; + a.digits.resize(length + 1, 0); + + let mut neg: bool = false; + + let mut res = MPNat { + digits: vec![0; length + 1], + }; + + let (mut wordpos, mut bitpos) = (0, 0); + + for _ in 0..k { + let x = b.digits[0] & 1; + if x != 0 { + if !neg { + // b = a - b + let mut tmp = MPNat { + digits: a.digits.clone(), + }; + in_place_mul_sub(&mut tmp.digits, &b.digits, 1); + b = tmp; + neg = true; + } else { + // b = b - a + in_place_add(&mut b.digits, &a.digits); + } + } + + in_place_shr(&mut b.digits, 1); + + res.digits[wordpos] |= x << bitpos; + + bitpos += 1; + if bitpos == WORD_BITS { + bitpos = 0; + wordpos += 1; + } + } + + res + } + pub fn from_big_endian(bytes: &[u8]) -> Self { if bytes.is_empty() { return Self { digits: vec![0] }; @@ -174,14 +229,11 @@ impl MPNat { let x1 = base_copy.modpow_montgomery(exp, &odd); let x2 = self.modpow_with_power_of_two(exp, &power_of_two); + let odd_inv = + Self::koc_2017_inverse(&odd, trailing_zeros * WORD_BITS + additional_zero_bits); + let s = power_of_two.digits.len(); let mut scratch = vec![0; s]; - let odd_inv = { - let mut tmp = MPNat { digits: vec![0; s] }; - big_mod_inv(&odd, &mut tmp, &mut scratch); - *tmp.digits.last_mut().unwrap() &= power_of_two_mask; - tmp - }; let diff = { scratch.fill(0); let mut b = false; @@ -348,7 +400,7 @@ impl MPNat { return; } - let other_most_sig = *other.digits.last().unwrap(); + let other_most_sig = *other.digits.last().unwrap() as DoubleWord; if self.digits.len() == 2 { // This is the smallest case since `n >= 1` and `m > 0` @@ -357,7 +409,7 @@ impl MPNat { // to get the answer directly. 
let self_most_sig = self.digits.pop().unwrap(); let a = join_as_double(self_most_sig, self.digits[0]); - let b = other_most_sig as DoubleWord; + let b = other_most_sig; self.digits[0] = (a % b) as Word; return; } @@ -369,8 +421,7 @@ impl MPNat { for j in (0..k).rev() { let self_most_sig = self.digits.pop().unwrap(); let self_second_sig = self.digits[j]; - let r = - join_as_double(self_most_sig, self_second_sig) % (other_most_sig as DoubleWord); + let r = join_as_double(self_most_sig, self_second_sig) % other_most_sig; self.digits[j] = r as Word; } return; @@ -385,7 +436,7 @@ impl MPNat { // both numerator and denominator by a common factor // and run the algorithm on those numbers. // See Knuth The Art of Computer Programming vol. 2 section 4.3 for details. - let shift = other_most_sig.leading_zeros(); + let shift = (other_most_sig as Word).leading_zeros(); if shift > 0 { // Normalize self let overflow = in_place_shl(&mut self.digits, shift); @@ -410,34 +461,31 @@ impl MPNat { return; } - let other_second_sig = other.digits[n - 2]; + let other_second_sig = other.digits[n - 2] as DoubleWord; let mut self_most_sig: Word = 0; for j in (0..=m).rev() { let self_second_sig = *self.digits.last().unwrap(); let self_third_sig = self.digits[self.digits.len() - 2]; - let (mut q_hat, mut r_hat) = { - let a = join_as_double(self_most_sig, self_second_sig); - let mut q_hat = a / (other_most_sig as DoubleWord); - let mut r_hat = a % (other_most_sig as DoubleWord); + let a = join_as_double(self_most_sig, self_second_sig); + let mut q_hat = a / other_most_sig; + let mut r_hat = a % other_most_sig; - if q_hat == BASE { + loop { + let a = q_hat * other_second_sig; + let b = join_as_double(r_hat as Word, self_third_sig); + if q_hat >= BASE || a > b { q_hat -= 1; - r_hat += other_most_sig as DoubleWord; + r_hat += other_most_sig; + if BASE <= r_hat { + break; + } + } else { + break; } - - (q_hat as Word, r_hat) - }; - - while r_hat < BASE - && join_as_double(r_hat as Word, 
self_third_sig) - < (q_hat as DoubleWord) * (other_second_sig as DoubleWord) - { - q_hat -= 1; - r_hat += other_most_sig as DoubleWord; } - let mut borrow = in_place_mul_sub(&mut self.digits[j..], &other.digits, q_hat); + let mut borrow = in_place_mul_sub(&mut self.digits[j..], &other.digits, q_hat as Word); if borrow > self_most_sig { // q_hat was too large, add back one multiple of the modulus let carry = in_place_add(&mut self.digits[j..], &other.digits); @@ -660,6 +708,34 @@ fn test_sub_to_same_size() { let result = crate::arith::mp_nat_to_u128(&x); assert_eq!(result % n, a % n, "{a} % {n} failed sub_to_same_size check"); } + + /* Test that borrow equals self_most_sig at end of sub_to_same_size */ + { + let mut x = MPNat::from_big_endian(&[ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xae, 0x5f, 0xf0, 0x8b, 0xfc, 0x02, + 0x71, 0xa4, 0xfe, 0xe0, 0x49, 0x02, 0xc9, 0xd9, 0x12, 0x61, 0x8e, 0xf5, 0x02, 0x2c, + 0xa0, 0x00, 0x00, 0x00, + ]); + let y = MPNat::from_big_endian(&[ + 0xae, 0x5f, 0xf0, 0x8b, 0xfc, 0x02, 0x71, 0xa4, 0xfe, 0xe0, 0x49, 0x0f, 0x70, 0x00, + 0x00, 0x00, + ]); + x.sub_to_same_size(&y); + } + + /* Additional test for sub_to_same_size q_hat/r_hat adjustment logic */ + { + let mut x = MPNat::from_big_endian(&[ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, + 0xff, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + ]); + let y = MPNat::from_big_endian(&[ + 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, + 0x00, 0x00, + ]); + x.sub_to_same_size(&y); + } } #[test]