crypto_bigint/modular/
monty_form.rs

1//! Implements `MontyForm`s, supporting modular arithmetic with a modulus set at runtime.
2
3mod add;
4pub(super) mod inv;
5mod lincomb;
6mod mul;
7mod neg;
8mod pow;
9mod sub;
10
11use super::{
12    const_monty_form::{ConstMontyForm, ConstMontyParams},
13    div_by_2::div_by_2,
14    reduction::montgomery_reduction,
15    Retrieve,
16};
17use crate::{Concat, Limb, Monty, NonZero, Odd, Split, Uint, Word};
18use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
19
/// Parameters to efficiently go to/from the Montgomery form for an odd modulus provided at runtime.
///
/// Throughout, `R = 2^BITS` where `BITS` is the bit size of `Uint<LIMBS>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct MontyParams<const LIMBS: usize> {
    /// The odd modulus these parameters were derived from.
    modulus: Odd<Uint<LIMBS>>,
    /// 1 in Montgomery form, i.e. `R mod modulus`.
    one: Uint<LIMBS>,
    /// `R^2 mod modulus`, used to move into Montgomery form
    r2: Uint<LIMBS>,
    /// `R^3 mod modulus`, used to compute the multiplicative inverse
    r3: Uint<LIMBS>,
    /// The least significant limb of `-(modulus^-1) mod R`.
    /// Only this single limb is needed because during reduction this value is multiplied
    /// modulo `2^Limb::BITS`.
    mod_neg_inv: Limb,
    /// Leading zeros in the modulus (capped at `Word::BITS - 1` by the constructors in this
    /// file), used to choose optimized algorithms
    mod_leading_zeros: u32,
}
37
38impl<const LIMBS: usize, const WIDE_LIMBS: usize> MontyParams<LIMBS>
39where
40    Uint<LIMBS>: Concat<Output = Uint<WIDE_LIMBS>>,
41    Uint<WIDE_LIMBS>: Split<Output = Uint<LIMBS>>,
42{
43    /// Instantiates a new set of `MontyParams` representing the given odd `modulus`.
44    pub fn new(modulus: Odd<Uint<LIMBS>>) -> Self {
45        // `R mod modulus` where `R = 2^BITS`.
46        // Represents 1 in Montgomery form.
47        let one = Uint::MAX.rem(modulus.as_nz_ref()).wrapping_add(&Uint::ONE);
48
49        // `R^2 mod modulus`, used to convert integers to Montgomery form.
50        let r2 = one
51            .square()
52            .rem(&NonZero(modulus.0.concat(&Uint::ZERO)))
53            .split()
54            .0;
55
56        // The modular inverse should always exist, because it was ensured odd above, which also ensures it's non-zero
57        let inv_mod = modulus
58            .inv_mod2k(Word::BITS)
59            .expect("modular inverse should exist");
60
61        let mod_neg_inv = Limb(Word::MIN.wrapping_sub(inv_mod.limbs[0].0));
62
63        let mod_leading_zeros = modulus.as_ref().leading_zeros().min(Word::BITS - 1);
64
65        // `R^3 mod modulus`, used for inversion in Montgomery form.
66        let r3 = montgomery_reduction(&r2.square_wide(), &modulus, mod_neg_inv);
67
68        Self {
69            modulus,
70            one,
71            r2,
72            r3,
73            mod_neg_inv,
74            mod_leading_zeros,
75        }
76    }
77}
78
79impl<const LIMBS: usize> MontyParams<LIMBS> {
80    /// Instantiates a new set of `MontyParams` representing the given odd `modulus`.
81    pub fn new_vartime(modulus: Odd<Uint<LIMBS>>) -> Self {
82        // `R mod modulus` where `R = 2^BITS`.
83        // Represents 1 in Montgomery form.
84        let one = Uint::MAX
85            .rem_vartime(modulus.as_nz_ref())
86            .wrapping_add(&Uint::ONE);
87
88        // `R^2 mod modulus`, used to convert integers to Montgomery form.
89        let r2 = Uint::rem_wide_vartime(one.square_wide(), modulus.as_nz_ref());
90
91        // The modular inverse should always exist, because it was ensured odd above, which also ensures it's non-zero
92        let inv_mod = modulus
93            .inv_mod2k_vartime(Word::BITS)
94            .expect("modular inverse should exist");
95
96        let mod_neg_inv = Limb(Word::MIN.wrapping_sub(inv_mod.limbs[0].0));
97
98        let mod_leading_zeros = modulus.as_ref().leading_zeros_vartime().min(Word::BITS - 1);
99
100        // `R^3 mod modulus`, used for inversion in Montgomery form.
101        let r3 = montgomery_reduction(&r2.square_wide(), &modulus, mod_neg_inv);
102
103        Self {
104            modulus,
105            one,
106            r2,
107            r3,
108            mod_neg_inv,
109            mod_leading_zeros,
110        }
111    }
112
113    /// Returns the modulus which was used to initialize these parameters.
114    pub const fn modulus(&self) -> &Odd<Uint<LIMBS>> {
115        &self.modulus
116    }
117
118    /// Create `MontyParams` corresponding to a `ConstMontyParams`.
119    pub const fn from_const_params<P>() -> Self
120    where
121        P: ConstMontyParams<LIMBS>,
122    {
123        Self {
124            modulus: Odd(P::MODULUS.0),
125            one: P::ONE,
126            r2: P::R2,
127            r3: P::R3,
128            mod_neg_inv: P::MOD_NEG_INV,
129            mod_leading_zeros: P::MOD_LEADING_ZEROS,
130        }
131    }
132}
133
134impl<const LIMBS: usize> ConditionallySelectable for MontyParams<LIMBS> {
135    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
136        Self {
137            modulus: Odd::conditional_select(&a.modulus, &b.modulus, choice),
138            one: Uint::conditional_select(&a.one, &b.one, choice),
139            r2: Uint::conditional_select(&a.r2, &b.r2, choice),
140            r3: Uint::conditional_select(&a.r3, &b.r3, choice),
141            mod_neg_inv: Limb::conditional_select(&a.mod_neg_inv, &b.mod_neg_inv, choice),
142            mod_leading_zeros: u32::conditional_select(
143                &a.mod_leading_zeros,
144                &b.mod_leading_zeros,
145                choice,
146            ),
147        }
148    }
149}
150
151impl<const LIMBS: usize> ConstantTimeEq for MontyParams<LIMBS> {
152    fn ct_eq(&self, other: &Self) -> Choice {
153        self.modulus.ct_eq(&other.modulus)
154            & self.one.ct_eq(&other.one)
155            & self.r2.ct_eq(&other.r2)
156            & self.r3.ct_eq(&other.r3)
157            & self.mod_neg_inv.ct_eq(&other.mod_neg_inv)
158    }
159}
160
#[cfg(feature = "zeroize")]
impl<const LIMBS: usize> zeroize::Zeroize for MontyParams<LIMBS> {
    /// Zeroizes every field of the parameter set.
    fn zeroize(&mut self) {
        // Exhaustive destructuring: adding a field without wiping it becomes a compile error.
        let Self {
            modulus,
            one,
            r2,
            r3,
            mod_neg_inv,
            mod_leading_zeros,
        } = self;

        modulus.zeroize();
        one.zeroize();
        r2.zeroize();
        r3.zeroize();
        mod_neg_inv.zeroize();
        mod_leading_zeros.zeroize();
    }
}
172
/// An integer in Montgomery form represented using `LIMBS` limbs.
/// The odd modulus is set at runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MontyForm<const LIMBS: usize> {
    /// The value, stored in Montgomery form.
    montgomery_form: Uint<LIMBS>,
    /// The parameters (modulus and precomputed constants) this value is tied to;
    /// every `MontyForm` carries its own copy.
    params: MontyParams<LIMBS>,
}
180
181impl<const LIMBS: usize> MontyForm<LIMBS> {
182    /// Instantiates a new `MontyForm` that represents this `integer` mod `MOD`.
183    pub const fn new(integer: &Uint<LIMBS>, params: MontyParams<LIMBS>) -> Self {
184        let product = integer.split_mul(&params.r2);
185        let montgomery_form = montgomery_reduction(&product, &params.modulus, params.mod_neg_inv);
186
187        Self {
188            montgomery_form,
189            params,
190        }
191    }
192
193    /// Retrieves the integer currently encoded in this `MontyForm`, guaranteed to be reduced.
194    pub const fn retrieve(&self) -> Uint<LIMBS> {
195        montgomery_reduction(
196            &(self.montgomery_form, Uint::ZERO),
197            &self.params.modulus,
198            self.params.mod_neg_inv,
199        )
200    }
201
202    /// Instantiates a new `MontyForm` that represents zero.
203    pub const fn zero(params: MontyParams<LIMBS>) -> Self {
204        Self {
205            montgomery_form: Uint::<LIMBS>::ZERO,
206            params,
207        }
208    }
209
210    /// Instantiates a new `MontyForm` that represents 1.
211    pub const fn one(params: MontyParams<LIMBS>) -> Self {
212        Self {
213            montgomery_form: params.one,
214            params,
215        }
216    }
217
218    /// Returns the parameter struct used to initialize this object.
219    pub const fn params(&self) -> &MontyParams<LIMBS> {
220        &self.params
221    }
222
223    /// Access the `MontyForm` value in Montgomery form.
224    pub const fn as_montgomery(&self) -> &Uint<LIMBS> {
225        &self.montgomery_form
226    }
227
228    /// Mutably access the `MontyForm` value in Montgomery form.
229    pub fn as_montgomery_mut(&mut self) -> &mut Uint<LIMBS> {
230        &mut self.montgomery_form
231    }
232
233    /// Create a `MontyForm` from a value in Montgomery form.
234    pub const fn from_montgomery(integer: Uint<LIMBS>, params: MontyParams<LIMBS>) -> Self {
235        Self {
236            montgomery_form: integer,
237            params,
238        }
239    }
240
241    /// Extract the value from the `MontyForm` in Montgomery form.
242    pub const fn to_montgomery(&self) -> Uint<LIMBS> {
243        self.montgomery_form
244    }
245
246    /// Performs division by 2, that is returns `x` such that `x + x = self`.
247    pub const fn div_by_2(&self) -> Self {
248        Self {
249            montgomery_form: div_by_2(&self.montgomery_form, &self.params.modulus),
250            params: self.params,
251        }
252    }
253}
254
255impl<const LIMBS: usize> Retrieve for MontyForm<LIMBS> {
256    type Output = Uint<LIMBS>;
257    fn retrieve(&self) -> Self::Output {
258        self.retrieve()
259    }
260}
261
262impl<const LIMBS: usize> Monty for MontyForm<LIMBS> {
263    type Integer = Uint<LIMBS>;
264    type Params = MontyParams<LIMBS>;
265
266    fn new_params_vartime(modulus: Odd<Self::Integer>) -> Self::Params {
267        MontyParams::new_vartime(modulus)
268    }
269
270    fn new(value: Self::Integer, params: Self::Params) -> Self {
271        MontyForm::new(&value, params)
272    }
273
274    fn zero(params: Self::Params) -> Self {
275        MontyForm::zero(params)
276    }
277
278    fn one(params: Self::Params) -> Self {
279        MontyForm::one(params)
280    }
281
282    fn params(&self) -> &Self::Params {
283        &self.params
284    }
285
286    fn as_montgomery(&self) -> &Self::Integer {
287        &self.montgomery_form
288    }
289
290    fn double(&self) -> Self {
291        MontyForm::double(self)
292    }
293
294    fn div_by_2(&self) -> Self {
295        MontyForm::div_by_2(self)
296    }
297
298    fn lincomb_vartime(products: &[(&Self, &Self)]) -> Self {
299        MontyForm::lincomb_vartime(products)
300    }
301}
302
303impl<const LIMBS: usize, P: ConstMontyParams<LIMBS>> From<&ConstMontyForm<P, LIMBS>>
304    for MontyForm<LIMBS>
305{
306    fn from(const_monty_form: &ConstMontyForm<P, LIMBS>) -> Self {
307        Self {
308            montgomery_form: const_monty_form.to_montgomery(),
309            params: MontyParams::from_const_params::<P>(),
310        }
311    }
312}
313
314impl<const LIMBS: usize> ConditionallySelectable for MontyForm<LIMBS> {
315    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
316        Self {
317            montgomery_form: Uint::conditional_select(
318                &a.montgomery_form,
319                &b.montgomery_form,
320                choice,
321            ),
322            params: MontyParams::conditional_select(&a.params, &b.params, choice),
323        }
324    }
325}
326
327impl<const LIMBS: usize> ConstantTimeEq for MontyForm<LIMBS> {
328    fn ct_eq(&self, other: &Self) -> Choice {
329        self.montgomery_form.ct_eq(&other.montgomery_form) & self.params.ct_eq(&other.params)
330    }
331}
332
#[cfg(feature = "zeroize")]
impl<const LIMBS: usize> zeroize::Zeroize for MontyForm<LIMBS> {
    /// Zeroizes the Montgomery value and the attached parameters.
    fn zeroize(&mut self) {
        // Exhaustive destructuring: adding a field without wiping it becomes a compile error.
        let Self {
            montgomery_form,
            params,
        } = self;

        montgomery_form.zeroize();
        params.zeroize();
    }
}
340
#[cfg(test)]
mod tests {
    use super::{Limb, MontyParams, Odd, Uint};

    /// Parameters built for the modulus 3 on a single limb must report `Limb::BITS - 2`
    /// leading zeros.
    #[test]
    fn new_params_with_valid_modulus() {
        let three = Uint::from(3u8);
        let odd_three = Odd::new(three).unwrap();

        let derived = MontyParams::<1>::new(odd_three);
        let expected_leading_zeros = Limb::BITS - 2;

        assert_eq!(derived.mod_leading_zeros, expected_leading_zeros);
    }
}