1use crate::{
4 div_limb::mul_rem,
5 modular::{MontyForm, MontyParams},
6 Concat, Limb, MulMod, NonZero, Split, Uint, WideWord, Word,
7};
8
9impl<const LIMBS: usize> Uint<LIMBS> {
10 pub fn mul_mod<const WIDE_LIMBS: usize>(
14 &self,
15 rhs: &Uint<LIMBS>,
16 p: &NonZero<Uint<LIMBS>>,
17 ) -> Uint<LIMBS>
18 where
19 Uint<LIMBS>: Concat<Output = Uint<WIDE_LIMBS>>,
20 Uint<WIDE_LIMBS>: Split<Output = Uint<LIMBS>>,
21 {
22 let params = MontyParams::new(p.to_odd().expect("p should be odd"));
29 (MontyForm::new(self, params) * MontyForm::new(rhs, params)).retrieve()
30 }
31
32 pub fn mul_mod_vartime(&self, rhs: &Uint<LIMBS>, p: &NonZero<Uint<LIMBS>>) -> Uint<LIMBS> {
34 let lo_hi = self.split_mul(rhs);
35 Self::rem_wide_vartime(lo_hi, p)
36 }
37
38 pub const fn mul_mod_special(&self, rhs: &Self, c: Limb) -> Self {
45 if LIMBS == 1 {
48 let reduced = mul_rem(
49 self.limbs[0],
50 rhs.limbs[0],
51 NonZero::<Limb>::new_unwrap(Limb(Word::MIN.wrapping_sub(c.0))),
52 );
53 return Self::from_word(reduced.0);
54 }
55
56 let (lo, hi) = self.split_mul(rhs);
57
58 let (lo, carry) = mac_by_limb(&lo, &hi, c, Limb::ZERO);
60
61 let (lo, carry) = {
62 let rhs = (carry.0 + 1) as WideWord * c.0 as WideWord;
63 lo.adc(&Self::from_wide_word(rhs), Limb::ZERO)
64 };
65
66 let (lo, _) = {
67 let rhs = carry.0.wrapping_sub(1) & c.0;
68 lo.sbb(&Self::from_word(rhs), Limb::ZERO)
69 };
70
71 lo
72 }
73}
74
75impl<const LIMBS: usize> MulMod for Uint<LIMBS> {
76 type Output = Self;
77
78 fn mul_mod(&self, rhs: &Self, p: &Self) -> Self {
79 self.mul_mod_vartime(rhs, &NonZero::new(*p).expect("p should be non-zero"))
80 }
81}
82
83const fn mac_by_limb<const LIMBS: usize>(
85 a: &Uint<LIMBS>,
86 b: &Uint<LIMBS>,
87 c: Limb,
88 carry: Limb,
89) -> (Uint<LIMBS>, Limb) {
90 let mut i = 0;
91 let mut a = *a;
92 let mut carry = carry;
93
94 while i < LIMBS {
95 (a.limbs[i], carry) = a.limbs[i].mac(b.limbs[i], c, carry);
96 i += 1;
97 }
98
99 (a, carry)
100}
101
#[cfg(all(test, feature = "rand"))]
mod tests {
    use crate::{Limb, NonZero, Random, RandomMod, Uint};
    use rand_core::SeedableRng;

    // Generates one `#[test]` that checks `mul_mod_special` for `Uint<$size>`:
    // first against hand-written base cases, then against a reference remainder of
    // the double-width product for 100 random operand pairs per modulus.
    macro_rules! test_mul_mod_special {
        ($size:expr, $test_name:ident) => {
            #[test]
            fn $test_name() {
                // Fixed seed keeps the generated moduli and operands reproducible.
                let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1);
                let moduli = [
                    NonZero::<Limb>::random(&mut rng),
                    NonZero::<Limb>::random(&mut rng),
                ];

                for special in &moduli {
                    let c_limb: Limb = *special.as_ref();

                    // p = 0 - c (wrapping), i.e. the special modulus `MAX + 1 - c`.
                    let modulus = NonZero::new(
                        Uint::<$size>::ZERO.wrapping_sub(&Uint::from(special.get())),
                    )
                    .unwrap();
                    let p = &modulus;

                    let zero = Uint::<$size>::ZERO;
                    let one = Uint::<$size>::ONE;
                    let minus_one = p.wrapping_sub(&one);

                    // (lhs, rhs, expected product mod p)
                    let base_cases = [
                        (zero, zero, zero),
                        (one, zero, zero),
                        (zero, one, zero),
                        (one, one, one),
                        (minus_one, minus_one, one),
                        (minus_one, one, minus_one),
                        (one, minus_one, minus_one),
                    ];
                    for (a, b, c) in &base_cases {
                        let x = a.mul_mod_special(b, c_limb);
                        assert_eq!(*c, x, "{} * {} mod {} = {} != {}", a, b, p, x, c);
                    }

                    for _ in 0..100 {
                        let a = Uint::<$size>::random_mod(&mut rng, p);
                        let b = Uint::<$size>::random_mod(&mut rng, p);

                        let actual = a.mul_mod_special(&b, c_limb);
                        assert!(actual < **p, "not reduced: {} >= {} ", actual, p);

                        // Reference: widen product and modulus to 2 * $size limbs,
                        // take the remainder, then truncate back to $size limbs.
                        let (lo, hi) = a.split_mul(&b);
                        let mut wide_product = Uint::<{ 2 * $size }>::ZERO;
                        wide_product.limbs[..$size].copy_from_slice(&lo.limbs);
                        wide_product.limbs[$size..].copy_from_slice(&hi.limbs);

                        let mut wide_modulus = Uint::<{ 2 * $size }>::ZERO;
                        wide_modulus.limbs[..$size].copy_from_slice(&p.as_ref().limbs);

                        let remainder =
                            wide_product.rem_vartime(&NonZero::new(wide_modulus).unwrap());
                        let mut expected = Uint::<$size>::ZERO;
                        expected.limbs[..].copy_from_slice(&remainder.limbs[..$size]);

                        assert_eq!(actual, expected, "incorrect result");
                    }
                }
            }
        };
    }

    test_mul_mod_special!(1, mul_mod_special_1);
    test_mul_mod_special!(2, mul_mod_special_2);
    test_mul_mod_special!(3, mul_mod_special_3);
    test_mul_mod_special!(4, mul_mod_special_4);
    test_mul_mod_special!(5, mul_mod_special_5);
    test_mul_mod_special!(6, mul_mod_special_6);
    test_mul_mod_special!(7, mul_mod_special_7);
    test_mul_mod_special!(8, mul_mod_special_8);
    test_mul_mod_special!(9, mul_mod_special_9);
    test_mul_mod_special!(10, mul_mod_special_10);
    test_mul_mod_special!(11, mul_mod_special_11);
    test_mul_mod_special!(12, mul_mod_special_12);
}