Skip to content

Commit

Permalink
engine-modexp bug fixes and performance improvements (#809)
Browse files Browse the repository at this point in the history
## Description

This fixes the following issues with `engine-modexp`:

- addition overflow in big_sq that would lead to incorrect output (and a
crash in debug mode)
- assert failure in sub_to_same_size that would lead to incorrect output
- slow modular exponentiation with even modulus. The bottleneck was the
routine where the inverse of A mod M was computed, where M is a power of
two. This inversion function has been replaced with the algorithm found
in [this paper](https://eprint.iacr.org/2017/411.pdf), which is much
faster.

## Performance / NEAR gas cost considerations

The performance of modular exponentiation with even modulus is improved.
Performance remains the same for odd modulus.

## Testing

Extensive fuzz testing, which was also the method used to detect the
issues that this PR addresses.

## How should this be reviewed

## Additional information

---------

Co-authored-by: Michael Birch <michael.birch@aurora.dev>
  • Loading branch information
guidovranken and birchmd authored Aug 1, 2023
1 parent 056c4c5 commit 4ecee7d
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 89 deletions.
94 changes: 37 additions & 57 deletions engine-modexp/src/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,32 +158,6 @@ pub fn mod_inv(x: Word) -> Word {
y
}

// Given x odd, computes `x^(-1) mod 2^(WORD_BITS*out.digits.len())`.
// See `MODULAR-INVERSE` in https://link.springer.com/content/pdf/10.1007/3-540-46877-3_21.pdf
//
// The inverse is written into `out.digits`, one bit at a time, starting from the
// least significant word. `scratch` is working space: it is indexed up to
// `out.digits.len() - 1`, so it must be at least that long, and it is zero-filled
// again after every multiplication.
// NOTE(review): `scratch` is presumably expected to be all zeros on entry, since
// it is only cleared *after* each `big_wrapping_mul` call — confirm against callers.
//
// Cost: one full `big_wrapping_mul` per output bit above the first word.
pub fn big_mod_inv(x: &MPNat, out: &mut MPNat, scratch: &mut [Word]) {
let s = out.digits.len();
// The lowest word of the inverse depends only on the lowest word of x.
out.digits[0] = mod_inv(x.digits[0]);

for digit_index in 1..s {
// Decide bits 0..WORD_BITS-2 of this digit (bit i - 1 on iteration i).
for i in 1..WORD_BITS {
let mask = (1 << i) - 1;
// Recompute the product of x and the partial inverse into scratch.
big_wrapping_mul(x, out, scratch);
// Keep only the low i bits of the product word at this position.
scratch[digit_index] &= mask;
let q = 1 << (i - 1);
// If bit i - 1 of the masked product word is set, set that bit in the output digit.
if scratch[digit_index] >= q {
out.digits[digit_index] += q;
}
scratch.fill(0);
}
// The top bit of the digit is handled separately; no mask is applied here
// (`1 << WORD_BITS` would not fit in a Word).
big_wrapping_mul(x, out, scratch);
let q = 1 << (WORD_BITS - 1);
if scratch[digit_index] >= q {
out.digits[digit_index] += q;
}
scratch.fill(0);
}
}

/// Computes R mod n, where R = 2^(WORD_BITS*k) and k = n.digits.len()
/// Note that if R = qn + r, q must be smaller than 2^WORD_BITS since `2^(WORD_BITS) * n > R`
/// (adding a whole additional word to n is too much).
Expand Down Expand Up @@ -294,14 +268,22 @@ pub fn big_sq(x: &MPNat, out: &mut [Word]) {
out[i + i] = product;
let mut c = carry as DoubleWord;
for j in (i + 1)..s {
let product = (x.digits[i] as DoubleWord) * (x.digits[j] as DoubleWord);
let (product, overflow) = product.overflowing_add(product);
let sum = (out[i + j] as DoubleWord) + product + c;
out[i + j] = sum as Word;
c = (sum >> WORD_BITS) as DoubleWord;
let mut new_c: DoubleWord = 0;
let res = (x.digits[i] as DoubleWord) * (x.digits[j] as DoubleWord);
let (res, overflow) = res.overflowing_add(res);
if overflow {
c += BASE;
new_c += BASE;
}
let (res, overflow) = (out[i + j] as DoubleWord).overflowing_add(res);
if overflow {
new_c += BASE;
}
let (res, overflow) = res.overflowing_add(c);
if overflow {
new_c += BASE;
}
out[i + j] = res as Word;
c = new_c + ((res >> WORD_BITS) as DoubleWord);
}
let (sum, carry) = carrying_add(out[i + s], c as Word, false);
out[i + s] = sum;
Expand Down Expand Up @@ -351,6 +333,11 @@ pub fn in_place_add(a: &mut [Word], b: &[Word]) -> bool {
pub fn in_place_mul_sub(a: &mut [Word], x: &[Word], y: Word) -> Word {
debug_assert!(a.len() == x.len());

// a -= x*0 leaves a unchanged, so return early
if y == 0 {
return 0;
}

// carry is between -big_digit::MAX and 0, so to avoid overflow we store
// offset_carry = carry + big_digit::MAX
let mut offset_carry = Word::MAX;
Expand Down Expand Up @@ -504,31 +491,6 @@ fn test_r_mod_n() {
}
}

// Checks `big_mod_inv` on a few fixed odd inputs: the input multiplied by the
// computed inverse must be congruent to 1 modulo 2^(WORD_BITS * s), where s is
// the number of digits of the input.
#[test]
fn test_big_mod_inv() {
check_big_mod_inv(0x02_FF_FF_FF);
check_big_mod_inv(0x1234_0000_DDDD_FFFF);
check_big_mod_inv(0x52DA_9A91_F82D_6E17_FDF8_6743_2B58_7917);

// Runs `big_mod_inv` on `n` with an output and scratch buffer of the same
// digit count as `n`, then verifies the product in u128 arithmetic.
fn check_big_mod_inv(n: u128) {
let x = MPNat::from_big_endian(&n.to_be_bytes());
let s = x.digits.len();
let mut result = MPNat { digits: vec![0; s] };
let mut scratch = vec![0; s];
big_mod_inv(&x, &mut result, &mut scratch);
let n_inv = mp_nat_to_u128(&result);
if WORD_BITS * s < u128::BITS as usize {
// The modulus is narrower than 128 bits, so reduce the product explicitly.
assert_eq!(
n.wrapping_mul(n_inv) % (1 << (WORD_BITS * s)),
1,
"{n} failed big_mod_inv check"
);
} else {
// The wrapping u128 multiplication already reduces modulo 2^128.
assert_eq!(n.wrapping_mul(n_inv), 1, "{n} failed big_mod_inv check");
}
}
}

#[test]
fn test_in_place_shl() {
check_in_place_shl(0, 0);
Expand Down Expand Up @@ -655,6 +617,24 @@ fn test_big_sq() {
};
assert_eq!(result, expected, "{a}^2 != {expected}");
}

/* Test for addition overflows in the big_sq inner loop */
{
let x = MPNat::from_big_endian(&[
0xff, 0xff, 0xff, 0xff, 0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x40, 0x00,
0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x80, 0x00, 0x00, 0x00,
]);
let mut out = vec![0; 2 * x.digits.len() + 1];
big_sq(&x, &mut out);
let result = MPNat { digits: out }.to_big_endian();
let expected = vec![
0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01, 0xff, 0xff, 0xff, 0xfe, 0x40, 0x00, 0x00, 0x01, 0x90, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0xbf, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
];
assert_eq!(result, expected);
}
}

#[test]
Expand Down
140 changes: 108 additions & 32 deletions engine-modexp/src/mpnat.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
arith::{
big_mod_inv, big_wrapping_mul, big_wrapping_pow, borrowing_sub, carrying_add,
compute_r_mod_n, in_place_add, in_place_mul_sub, in_place_shl, in_place_shr,
join_as_double, mod_inv, monpro, monsq,
big_wrapping_mul, big_wrapping_pow, borrowing_sub, carrying_add, compute_r_mod_n,
in_place_add, in_place_mul_sub, in_place_shl, in_place_shr, join_as_double, mod_inv,
monpro, monsq,
},
maybe_std::{vec, Vec},
};
Expand All @@ -22,6 +22,61 @@ pub struct MPNat {
}

impl MPNat {
// Koç's algorithm for inversion mod 2^k
// https://eprint.iacr.org/2017/411.pdf
//
// Returns the inverse of `aa` modulo 2^k, produced one bit per loop iteration
// (least significant bit first). `aa` must be odd; this is only checked in
// debug builds. The returned value has `k / WORD_BITS + 1` digits.
//
// The recurrence being implemented is: x_i = b mod 2; b = (b - a * x_i) / 2,
// starting from b = 1, with x_i the i-th bit of the result.
fn koc_2017_inverse(aa: &Self, k: usize) -> Self {
debug_assert!(aa.is_odd());

// All working values use length + 1 words, which is always more than k bits.
let length = k / WORD_BITS;
let mut b = MPNat {
digits: vec![0; length + 1],
};
b.digits[0] = 1;

// Working copy of the input, truncated or zero-extended to length + 1 words.
let mut a = MPNat {
digits: aa.digits.clone(),
};
a.digits.resize(length + 1, 0);

// Once set, `b` holds the negation of the recurrence value (see below).
let mut neg: bool = false;

let mut res = MPNat {
digits: vec![0; length + 1],
};

// Position in `res` of the next bit to be written.
let (mut wordpos, mut bitpos) = (0, 0);

for _ in 0..k {
// Lowest bit of b is the next bit of the inverse.
let x = b.digits[0] & 1;
if x != 0 {
if !neg {
// b = a - b
// First odd step: store a - b, i.e. the negation of the recurrence's
// b - a. The value returned by in_place_mul_sub is discarded.
let mut tmp = MPNat {
digits: a.digits.clone(),
};
in_place_mul_sub(&mut tmp.digits, &b.digits, 1);
b = tmp;
neg = true;
} else {
// Recurrence step "b = b - a": because b is stored negated from here
// on, the subtraction is carried out as b += a. The value returned by
// in_place_add is discarded.
in_place_add(&mut b.digits, &a.digits);
}
}

// Divide by two for the next iteration
// (presumably a 1-bit right shift — confirm in arith.rs).
in_place_shr(&mut b.digits, 1);

res.digits[wordpos] |= x << bitpos;

// Advance the output cursor, moving to the next word after WORD_BITS bits.
bitpos += 1;
if bitpos == WORD_BITS {
bitpos = 0;
wordpos += 1;
}
}

res
}

pub fn from_big_endian(bytes: &[u8]) -> Self {
if bytes.is_empty() {
return Self { digits: vec![0] };
Expand Down Expand Up @@ -174,14 +229,11 @@ impl MPNat {
let x1 = base_copy.modpow_montgomery(exp, &odd);
let x2 = self.modpow_with_power_of_two(exp, &power_of_two);

let odd_inv =
Self::koc_2017_inverse(&odd, trailing_zeros * WORD_BITS + additional_zero_bits);

let s = power_of_two.digits.len();
let mut scratch = vec![0; s];
let odd_inv = {
let mut tmp = MPNat { digits: vec![0; s] };
big_mod_inv(&odd, &mut tmp, &mut scratch);
*tmp.digits.last_mut().unwrap() &= power_of_two_mask;
tmp
};
let diff = {
scratch.fill(0);
let mut b = false;
Expand Down Expand Up @@ -348,7 +400,7 @@ impl MPNat {
return;
}

let other_most_sig = *other.digits.last().unwrap();
let other_most_sig = *other.digits.last().unwrap() as DoubleWord;

if self.digits.len() == 2 {
// This is the smallest case since `n >= 1` and `m > 0`
Expand All @@ -357,7 +409,7 @@ impl MPNat {
// to get the answer directly.
let self_most_sig = self.digits.pop().unwrap();
let a = join_as_double(self_most_sig, self.digits[0]);
let b = other_most_sig as DoubleWord;
let b = other_most_sig;
self.digits[0] = (a % b) as Word;
return;
}
Expand All @@ -369,8 +421,7 @@ impl MPNat {
for j in (0..k).rev() {
let self_most_sig = self.digits.pop().unwrap();
let self_second_sig = self.digits[j];
let r =
join_as_double(self_most_sig, self_second_sig) % (other_most_sig as DoubleWord);
let r = join_as_double(self_most_sig, self_second_sig) % other_most_sig;
self.digits[j] = r as Word;
}
return;
Expand All @@ -385,7 +436,7 @@ impl MPNat {
// both numerator and denominator by a common factor
// and run the algorithm on those numbers.
// See Knuth The Art of Computer Programming vol. 2 section 4.3 for details.
let shift = other_most_sig.leading_zeros();
let shift = (other_most_sig as Word).leading_zeros();
if shift > 0 {
// Normalize self
let overflow = in_place_shl(&mut self.digits, shift);
Expand All @@ -410,34 +461,31 @@ impl MPNat {
return;
}

let other_second_sig = other.digits[n - 2];
let other_second_sig = other.digits[n - 2] as DoubleWord;
let mut self_most_sig: Word = 0;
for j in (0..=m).rev() {
let self_second_sig = *self.digits.last().unwrap();
let self_third_sig = self.digits[self.digits.len() - 2];

let (mut q_hat, mut r_hat) = {
let a = join_as_double(self_most_sig, self_second_sig);
let mut q_hat = a / (other_most_sig as DoubleWord);
let mut r_hat = a % (other_most_sig as DoubleWord);
let a = join_as_double(self_most_sig, self_second_sig);
let mut q_hat = a / other_most_sig;
let mut r_hat = a % other_most_sig;

if q_hat == BASE {
loop {
let a = q_hat * other_second_sig;
let b = join_as_double(r_hat as Word, self_third_sig);
if q_hat >= BASE || a > b {
q_hat -= 1;
r_hat += other_most_sig as DoubleWord;
r_hat += other_most_sig;
if BASE <= r_hat {
break;
}
} else {
break;
}

(q_hat as Word, r_hat)
};

while r_hat < BASE
&& join_as_double(r_hat as Word, self_third_sig)
< (q_hat as DoubleWord) * (other_second_sig as DoubleWord)
{
q_hat -= 1;
r_hat += other_most_sig as DoubleWord;
}

let mut borrow = in_place_mul_sub(&mut self.digits[j..], &other.digits, q_hat);
let mut borrow = in_place_mul_sub(&mut self.digits[j..], &other.digits, q_hat as Word);
if borrow > self_most_sig {
// q_hat was too large, add back one multiple of the modulus
let carry = in_place_add(&mut self.digits[j..], &other.digits);
Expand Down Expand Up @@ -660,6 +708,34 @@ fn test_sub_to_same_size() {
let result = crate::arith::mp_nat_to_u128(&x);
assert_eq!(result % n, a % n, "{a} % {n} failed sub_to_same_size check");
}

/* Test that borrow equals self_most_sig at end of sub_to_same_size */
{
let mut x = MPNat::from_big_endian(&[
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xae, 0x5f, 0xf0, 0x8b, 0xfc, 0x02,
0x71, 0xa4, 0xfe, 0xe0, 0x49, 0x02, 0xc9, 0xd9, 0x12, 0x61, 0x8e, 0xf5, 0x02, 0x2c,
0xa0, 0x00, 0x00, 0x00,
]);
let y = MPNat::from_big_endian(&[
0xae, 0x5f, 0xf0, 0x8b, 0xfc, 0x02, 0x71, 0xa4, 0xfe, 0xe0, 0x49, 0x0f, 0x70, 0x00,
0x00, 0x00,
]);
x.sub_to_same_size(&y);
}

/* Additional test for sub_to_same_size q_hat/r_hat adjustment logic */
{
let mut x = MPNat::from_big_endian(&[
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff,
0xff, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
]);
let y = MPNat::from_big_endian(&[
0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00,
0x00, 0x00,
]);
x.sub_to_same_size(&y);
}
}

#[test]
Expand Down

0 comments on commit 4ecee7d

Please sign in to comment.