diff --git a/src/polynomial.rs b/src/polynomial.rs index a6843ea..acba1de 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -57,34 +57,67 @@ impl Mul for Polynomial { fn mul(self, rhs: Polynomial) -> Self::Output { let mut v1 = self.coef; let mut v2 = rhs.coef; - let n = v1.len().max(v2.len()) as i64; - if v1.len() > v2.len() { - swap(&mut v1, &mut v2); - } + let n = (v1.len() + v2.len()) as i64; + let v1_deg = v1.len() - 1; + let v2_deg = v2.len() - 1; + v1 = vec![0; (n - v1.len() as i64) as usize] .into_iter() .chain(v1.into_iter()) .collect(); + v2 = vec![0; (n - v2.len() as i64) as usize] + .into_iter() + .chain(v2.into_iter()) + .collect(); - let M = v1.iter().max().unwrap().pow(2) as i64 * n + 1; + let M = v1 + .iter() + .map(|x| x.abs()) + .max() + .unwrap() + .max(v2.iter().map(|x| x.abs()).max().unwrap()) + .pow(2) as i64 + * n + + 1; let c = working_modulus(n, M); + println!("consts -- {} {:?}", M, c); + + v1.iter_mut().for_each(|x| { + if *x < 0 { + *x = (*x).rem_euclid(M) + } + }); + v2.iter_mut().for_each(|x| { + if *x < 0 { + *x = (*x).rem_euclid(M) + } + }); + + println!("v1 -- {:?}", v1); + println!("v2 -- {:?}", v2); + let a_forward = forward(v1, &c); let b_forward = forward(v2, &c); - let mul = a_forward + let mut mul: Vec = vec![0; n as usize]; + a_forward .iter() .rev() .zip_longest(b_forward.iter().rev()) - .map(|p| match p { - Both(&a, &b) => (a * b) % c.N, - Left(&_a) => 0, - Right(&_b) => 0, - }) - .rev() - .collect::>(); - Polynomial { - coef: inverse(mul, &c), - } + .enumerate() + .for_each(|(i, p)| match p { + Both(&a, &b) => mul[i] = (a * b) % c.N, + Left(_) => {} + Right(_) => {} + }); + mul.reverse(); + let coef = inverse(mul.clone(), &c); + let coef = inverse(mul, &c) + .iter() + .map(|&x| if x > M / 2 { -(M - x.rem_euclid(M)) } else { x }) + .collect::>()[..=(v1_deg + v2_deg)] + .to_vec(); + Polynomial { coef } } } @@ -103,8 +136,12 @@ mod tests { #[test] fn mul() { - let a = Polynomial { coef: vec![1, 2] }; - let b = Polynomial { coef: vec![1] }; + let a = Polynomial { + coef: vec![1, 2, -3], + }; + let b = Polynomial { + coef: vec![1, -5, 4, -8], + }; println!("{:?}", a * b); } }