Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

engine-modexp bug fixes and performance improvements #809

Merged
merged 8 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 37 additions & 57 deletions engine-modexp/src/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,32 +158,6 @@ pub fn mod_inv(x: Word) -> Word {
y
}

// Given x odd, computes `x^(-1) mod 2^(WORD_BYTES*out.digits.len())`.
// See `MODULAR-INVERSE` in https://link.springer.com/content/pdf/10.1007/3-540-46877-3_21.pdf
pub fn big_mod_inv(x: &MPNat, out: &mut MPNat, scratch: &mut [Word]) {
let s = out.digits.len();
out.digits[0] = mod_inv(x.digits[0]);

for digit_index in 1..s {
for i in 1..WORD_BITS {
let mask = (1 << i) - 1;
big_wrapping_mul(x, out, scratch);
scratch[digit_index] &= mask;
let q = 1 << (i - 1);
if scratch[digit_index] >= q {
out.digits[digit_index] += q;
}
scratch.fill(0);
}
big_wrapping_mul(x, out, scratch);
let q = 1 << (WORD_BITS - 1);
if scratch[digit_index] >= q {
out.digits[digit_index] += q;
}
scratch.fill(0);
}
}

/// Computes R mod n, where R = 2^(WORD_BITS*k) and k = n.digits.len()
/// Note that if R = qn + r, q must be smaller than 2^WORD_BITS since `2^(WORD_BITS) * n > R`
/// (adding a whole additional word to n is too much).
Expand Down Expand Up @@ -294,14 +268,22 @@ pub fn big_sq(x: &MPNat, out: &mut [Word]) {
out[i + i] = product;
let mut c = carry as DoubleWord;
for j in (i + 1)..s {
let product = (x.digits[i] as DoubleWord) * (x.digits[j] as DoubleWord);
let (product, overflow) = product.overflowing_add(product);
let sum = (out[i + j] as DoubleWord) + product + c;
out[i + j] = sum as Word;
c = (sum >> WORD_BITS) as DoubleWord;
let mut new_c: DoubleWord = 0;
let res = (x.digits[i] as DoubleWord) * (x.digits[j] as DoubleWord);
let (res, overflow) = res.overflowing_add(res);
if overflow {
c += BASE;
new_c += BASE;
}
let (res, overflow) = (out[i + j] as DoubleWord).overflowing_add(res);
if overflow {
new_c += BASE;
}
let (res, overflow) = res.overflowing_add(c);
if overflow {
new_c += BASE;
}
out[i + j] = res as Word;
c = new_c + ((res >> WORD_BITS) as DoubleWord);
}
let (sum, carry) = carrying_add(out[i + s], c as Word, false);
out[i + s] = sum;
Expand Down Expand Up @@ -351,6 +333,11 @@ pub fn in_place_add(a: &mut [Word], b: &[Word]) -> bool {
pub fn in_place_mul_sub(a: &mut [Word], x: &[Word], y: Word) -> Word {
debug_assert!(a.len() == x.len());

// a -= x*0 leaves a unchanged, so return early
if y == 0 {
return 0;
}

// carry is between -big_digit::MAX and 0, so to avoid overflow we store
// offset_carry = carry + big_digit::MAX
let mut offset_carry = Word::MAX;
Expand Down Expand Up @@ -504,31 +491,6 @@ fn test_r_mod_n() {
}
}

#[test]
fn test_big_mod_inv() {
check_big_mod_inv(0x02_FF_FF_FF);
check_big_mod_inv(0x1234_0000_DDDD_FFFF);
check_big_mod_inv(0x52DA_9A91_F82D_6E17_FDF8_6743_2B58_7917);

fn check_big_mod_inv(n: u128) {
let x = MPNat::from_big_endian(&n.to_be_bytes());
let s = x.digits.len();
let mut result = MPNat { digits: vec![0; s] };
let mut scratch = vec![0; s];
big_mod_inv(&x, &mut result, &mut scratch);
let n_inv = mp_nat_to_u128(&result);
if WORD_BITS * s < u128::BITS as usize {
assert_eq!(
n.wrapping_mul(n_inv) % (1 << (WORD_BITS * s)),
1,
"{n} failed big_mod_inv check"
);
} else {
assert_eq!(n.wrapping_mul(n_inv), 1, "{n} failed big_mod_inv check");
}
}
}

#[test]
fn test_in_place_shl() {
check_in_place_shl(0, 0);
Expand Down Expand Up @@ -655,6 +617,24 @@ fn test_big_sq() {
};
assert_eq!(result, expected, "{a}^2 != {expected}");
}

/* Test for addition overflows in the big_sq inner loop */
{
let x = MPNat::from_big_endian(&[
0xff, 0xff, 0xff, 0xff, 0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x40, 0x00,
0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x80, 0x00, 0x00, 0x00,
]);
let mut out = vec![0; 2 * x.digits.len() + 1];
big_sq(&x, &mut out);
let result = MPNat { digits: out }.to_big_endian();
let expected = vec![
0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01, 0xff, 0xff, 0xff, 0xfe, 0x40, 0x00, 0x00, 0x01, 0x90, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0xbf, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
];
assert_eq!(result, expected);
}
}

#[test]
Expand Down
140 changes: 108 additions & 32 deletions engine-modexp/src/mpnat.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
arith::{
big_mod_inv, big_wrapping_mul, big_wrapping_pow, borrowing_sub, carrying_add,
compute_r_mod_n, in_place_add, in_place_mul_sub, in_place_shl, in_place_shr,
join_as_double, mod_inv, monpro, monsq,
big_wrapping_mul, big_wrapping_pow, borrowing_sub, carrying_add, compute_r_mod_n,
in_place_add, in_place_mul_sub, in_place_shl, in_place_shr, join_as_double, mod_inv,
monpro, monsq,
},
maybe_std::{vec, Vec},
};
Expand All @@ -22,6 +22,61 @@ pub struct MPNat {
}

impl MPNat {
// Koç's algorithm for inversion mod 2^k
// https://eprint.iacr.org/2017/411.pdf
fn koc_2017_inverse(aa: &Self, k: usize) -> Self {
debug_assert!(aa.is_odd());

let length = k / WORD_BITS;
let mut b = MPNat {
digits: vec![0; length + 1],
};
b.digits[0] = 1;

let mut a = MPNat {
digits: aa.digits.clone(),
};
a.digits.resize(length + 1, 0);

let mut neg: bool = false;

let mut res = MPNat {
digits: vec![0; length + 1],
};

let (mut wordpos, mut bitpos) = (0, 0);

for _ in 0..k {
let x = b.digits[0] & 1;
if x != 0 {
if !neg {
// b = a - b
let mut tmp = MPNat {
digits: a.digits.clone(),
};
in_place_mul_sub(&mut tmp.digits, &b.digits, 1);
b = tmp;
neg = true;
} else {
// b = b - a
in_place_add(&mut b.digits, &a.digits);
}
}

in_place_shr(&mut b.digits, 1);

res.digits[wordpos] |= x << bitpos;

bitpos += 1;
if bitpos == WORD_BITS {
bitpos = 0;
wordpos += 1;
}
}

res
}

pub fn from_big_endian(bytes: &[u8]) -> Self {
if bytes.is_empty() {
return Self { digits: vec![0] };
Expand Down Expand Up @@ -174,14 +229,11 @@ impl MPNat {
let x1 = base_copy.modpow_montgomery(exp, &odd);
let x2 = self.modpow_with_power_of_two(exp, &power_of_two);

let odd_inv =
Self::koc_2017_inverse(&odd, trailing_zeros * WORD_BITS + additional_zero_bits);

let s = power_of_two.digits.len();
let mut scratch = vec![0; s];
let odd_inv = {
let mut tmp = MPNat { digits: vec![0; s] };
big_mod_inv(&odd, &mut tmp, &mut scratch);
*tmp.digits.last_mut().unwrap() &= power_of_two_mask;
tmp
};
let diff = {
scratch.fill(0);
let mut b = false;
Expand Down Expand Up @@ -348,7 +400,7 @@ impl MPNat {
return;
}

let other_most_sig = *other.digits.last().unwrap();
let other_most_sig = *other.digits.last().unwrap() as DoubleWord;

if self.digits.len() == 2 {
// This is the smallest case since `n >= 1` and `m > 0`
Expand All @@ -357,7 +409,7 @@ impl MPNat {
// to get the answer directly.
let self_most_sig = self.digits.pop().unwrap();
let a = join_as_double(self_most_sig, self.digits[0]);
let b = other_most_sig as DoubleWord;
let b = other_most_sig;
self.digits[0] = (a % b) as Word;
return;
}
Expand All @@ -369,8 +421,7 @@ impl MPNat {
for j in (0..k).rev() {
let self_most_sig = self.digits.pop().unwrap();
let self_second_sig = self.digits[j];
let r =
join_as_double(self_most_sig, self_second_sig) % (other_most_sig as DoubleWord);
let r = join_as_double(self_most_sig, self_second_sig) % other_most_sig;
self.digits[j] = r as Word;
}
return;
Expand All @@ -385,7 +436,7 @@ impl MPNat {
// both numerator and denominator by a common factor
// and run the algorithm on those numbers.
// See Knuth The Art of Computer Programming vol. 2 section 4.3 for details.
let shift = other_most_sig.leading_zeros();
let shift = (other_most_sig as Word).leading_zeros();
if shift > 0 {
// Normalize self
let overflow = in_place_shl(&mut self.digits, shift);
Expand All @@ -410,34 +461,31 @@ impl MPNat {
return;
}

let other_second_sig = other.digits[n - 2];
let other_second_sig = other.digits[n - 2] as DoubleWord;
let mut self_most_sig: Word = 0;
for j in (0..=m).rev() {
let self_second_sig = *self.digits.last().unwrap();
let self_third_sig = self.digits[self.digits.len() - 2];

let (mut q_hat, mut r_hat) = {
let a = join_as_double(self_most_sig, self_second_sig);
let mut q_hat = a / (other_most_sig as DoubleWord);
let mut r_hat = a % (other_most_sig as DoubleWord);
let a = join_as_double(self_most_sig, self_second_sig);
let mut q_hat = a / other_most_sig;
let mut r_hat = a % other_most_sig;

if q_hat == BASE {
loop {
let a = q_hat * other_second_sig;
let b = join_as_double(r_hat as Word, self_third_sig);
if q_hat >= BASE || a > b {
q_hat -= 1;
r_hat += other_most_sig as DoubleWord;
r_hat += other_most_sig;
if BASE <= r_hat {
break;
}
} else {
break;
}

(q_hat as Word, r_hat)
};

while r_hat < BASE
&& join_as_double(r_hat as Word, self_third_sig)
< (q_hat as DoubleWord) * (other_second_sig as DoubleWord)
{
q_hat -= 1;
r_hat += other_most_sig as DoubleWord;
}

let mut borrow = in_place_mul_sub(&mut self.digits[j..], &other.digits, q_hat);
let mut borrow = in_place_mul_sub(&mut self.digits[j..], &other.digits, q_hat as Word);
if borrow > self_most_sig {
// q_hat was too large, add back one multiple of the modulus
let carry = in_place_add(&mut self.digits[j..], &other.digits);
Expand Down Expand Up @@ -660,6 +708,34 @@ fn test_sub_to_same_size() {
let result = crate::arith::mp_nat_to_u128(&x);
assert_eq!(result % n, a % n, "{a} % {n} failed sub_to_same_size check");
}

/* Test that borrow equals self_most_sig at end of sub_to_same_size */
{
let mut x = MPNat::from_big_endian(&[
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xae, 0x5f, 0xf0, 0x8b, 0xfc, 0x02,
0x71, 0xa4, 0xfe, 0xe0, 0x49, 0x02, 0xc9, 0xd9, 0x12, 0x61, 0x8e, 0xf5, 0x02, 0x2c,
0xa0, 0x00, 0x00, 0x00,
]);
let y = MPNat::from_big_endian(&[
0xae, 0x5f, 0xf0, 0x8b, 0xfc, 0x02, 0x71, 0xa4, 0xfe, 0xe0, 0x49, 0x0f, 0x70, 0x00,
0x00, 0x00,
]);
x.sub_to_same_size(&y);
}

/* Additional test for sub_to_same_size q_hat/r_hat adjustment logic */
{
let mut x = MPNat::from_big_endian(&[
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
]);
let y = MPNat::from_big_endian(&[
0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00,
0x00, 0x00,
]);
x.sub_to_same_size(&y);
}
}

#[test]
Expand Down
Loading