crypto_bigint/uint/
inv_mod.rs

1use super::Uint;
2use crate::{
3    modular::SafeGcdInverter, ConstChoice, ConstCtOption, InvMod, Odd, PrecomputeInverter,
4};
5use subtle::CtOption;
6
7impl<const LIMBS: usize> Uint<LIMBS> {
8    /// Computes 1/`self` mod `2^k`.
9    /// This method is constant-time w.r.t. `self` but not `k`.
10    ///
11    /// If the inverse does not exist (`k > 0` and `self` is even),
12    /// returns `ConstChoice::FALSE` as the second element of the tuple,
13    /// otherwise returns `ConstChoice::TRUE`.
14    pub const fn inv_mod2k_vartime(&self, k: u32) -> ConstCtOption<Self> {
15        // Using the Algorithm 3 from "A Secure Algorithm for Inversion Modulo 2k"
16        // by Sadiel de la Fe and Carles Ferrer.
17        // See <https://www.mdpi.com/2410-387X/2/3/23>.
18
19        // Note that we are not using Alrgorithm 4, since we have a different approach
20        // of enforcing constant-timeness w.r.t. `self`.
21
22        let mut x = Self::ZERO; // keeps `x` during iterations
23        let mut b = Self::ONE; // keeps `b_i` during iterations
24        let mut i = 0;
25
26        // The inverse exists either if `k` is 0 or if `self` is odd.
27        let is_some = ConstChoice::from_u32_nonzero(k).not().or(self.is_odd());
28
29        while i < k {
30            // X_i = b_i mod 2
31            let x_i = b.limbs[0].0 & 1;
32            let x_i_choice = ConstChoice::from_word_lsb(x_i);
33            // b_{i+1} = (b_i - a * X_i) / 2
34            b = Self::select(&b, &b.wrapping_sub(self), x_i_choice).shr1();
35            // Store the X_i bit in the result (x = x | (1 << X_i))
36            let shifted = Uint::from_word(x_i)
37                .overflowing_shl_vartime(i)
38                .expect("shift within range");
39            x = x.bitor(&shifted);
40
41            i += 1;
42        }
43
44        ConstCtOption::new(x, is_some)
45    }
46
47    /// Computes 1/`self` mod `2^k`.
48    ///
49    /// If the inverse does not exist (`k > 0` and `self` is even),
50    /// returns `ConstChoice::FALSE` as the second element of the tuple,
51    /// otherwise returns `ConstChoice::TRUE`.
52    pub const fn inv_mod2k(&self, k: u32) -> ConstCtOption<Self> {
53        // This is the same algorithm as in `inv_mod2k_vartime()`,
54        // but made constant-time w.r.t `k` as well.
55
56        let mut x = Self::ZERO; // keeps `x` during iterations
57        let mut b = Self::ONE; // keeps `b_i` during iterations
58        let mut i = 0;
59
60        // The inverse exists either if `k` is 0 or if `self` is odd.
61        let is_some = ConstChoice::from_u32_nonzero(k).not().or(self.is_odd());
62
63        while i < Self::BITS {
64            // Only iterations for i = 0..k need to change `x`,
65            // the rest are dummy ones performed for the sake of constant-timeness.
66            let within_range = ConstChoice::from_u32_lt(i, k);
67
68            // X_i = b_i mod 2
69            let x_i = b.limbs[0].0 & 1;
70            let x_i_choice = ConstChoice::from_word_lsb(x_i);
71            // b_{i+1} = (b_i - self * X_i) / 2
72            b = Self::select(&b, &b.wrapping_sub(self), x_i_choice).shr1();
73
74            // Store the X_i bit in the result (x = x | (1 << X_i))
75            // Don't change the result in dummy iterations.
76            let x_i_choice = x_i_choice.and(within_range);
77            x = x.set_bit(i, x_i_choice);
78
79            i += 1;
80        }
81
82        ConstCtOption::new(x, is_some)
83    }
84}
85
86impl<const LIMBS: usize, const UNSAT_LIMBS: usize> Uint<LIMBS>
87where
88    Odd<Self>: PrecomputeInverter<Inverter = SafeGcdInverter<LIMBS, UNSAT_LIMBS>>,
89{
90    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
91    pub const fn inv_odd_mod(&self, modulus: &Odd<Self>) -> ConstCtOption<Self> {
92        SafeGcdInverter::<LIMBS, UNSAT_LIMBS>::new(modulus, &Uint::ONE).inv(self)
93    }
94
95    /// Computes the multiplicative inverse of `self` mod `modulus`.
96    ///
97    /// Returns some if an inverse exists, otherwise none.
98    pub const fn inv_mod(&self, modulus: &Self) -> ConstCtOption<Self> {
99        // Decompose `modulus = s * 2^k` where `s` is odd
100        let k = modulus.trailing_zeros();
101        let s = modulus.overflowing_shr(k).unwrap_or(Self::ZERO);
102
103        // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
104        // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
105        let s_is_odd = s.is_odd();
106        let maybe_a = self.inv_odd_mod(&Odd(s)).and_choice(s_is_odd);
107
108        let maybe_b = self.inv_mod2k(k);
109        let is_some = maybe_a.is_some().and(maybe_b.is_some());
110
111        // Unwrap to avoid mapping through ConstCtOptions.
112        // if `a` or `b` don't exist, the returned ConstCtOption will be None anyway.
113        let a = maybe_a.unwrap_or(Uint::ZERO);
114        let b = maybe_b.unwrap_or(Uint::ZERO);
115
116        // Restore from RNS:
117        // self^{-1} = a mod s = b mod 2^k
118        // => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
119        // (essentially one step of the Garner's algorithm for recovery from RNS).
120
121        // `s` is odd, so this always exists
122        let m_odd_inv = s.inv_mod2k(k).expect("inverse mod 2^k exists");
123
124        // This part is mod 2^k
125        let shifted = Uint::ONE.overflowing_shl(k).unwrap_or(Self::ZERO);
126        let mask = shifted.wrapping_sub(&Uint::ONE);
127        let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask);
128
129        // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
130        // so `a + s * t <= s * 2^k - 1 == modulus - 1`.
131        let result = a.wrapping_add(&s.wrapping_mul(&t));
132        ConstCtOption::new(result, is_some)
133    }
134}
135
136impl<const LIMBS: usize, const UNSAT_LIMBS: usize> InvMod for Uint<LIMBS>
137where
138    Odd<Self>: PrecomputeInverter<Inverter = SafeGcdInverter<LIMBS, UNSAT_LIMBS>>,
139{
140    type Output = Self;
141
142    fn inv_mod(&self, modulus: &Self) -> CtOption<Self> {
143        self.inv_mod(modulus).into()
144    }
145}
146
#[cfg(test)]
mod tests {
    use crate::{U1024, U256, U64};

    /// Inversion modulo a power of two, checked for both the constant-time
    /// and the variable-time implementations.
    #[test]
    fn inv_mod2k() {
        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f");
        let e =
            U256::from_be_hex("3642e6faeaac7c6663b93d3d6a0d489e434ddc0123db5fa627c7f6e22ddacacf");
        let a = v.inv_mod2k(256).unwrap();
        assert_eq!(e, a);

        let a = v.inv_mod2k_vartime(256).unwrap();
        assert_eq!(e, a);

        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("261776f29b6b106c7680cf3ed83054a1af5ae537cb4613dbb4f20099aa774ec1");
        let a = v.inv_mod2k(256).unwrap();
        assert_eq!(e, a);

        let a = v.inv_mod2k_vartime(256).unwrap();
        assert_eq!(e, a);

        // Check that even if the number is >= 2^k, the inverse is still correct.

        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("0000000000000000000000000000000000000000034613dbb4f20099aa774ec1");
        let a = v.inv_mod2k(90).unwrap();
        assert_eq!(e, a);

        let a = v.inv_mod2k_vartime(90).unwrap();
        assert_eq!(e, a);

        // An inverse of an even number does not exist.

        let a = U256::from(10u64).inv_mod2k(4);
        assert!(a.is_none().is_true_vartime());

        let a = U256::from(10u64).inv_mod2k_vartime(4);
        assert!(a.is_none().is_true_vartime());

        // A degenerate case. An inverse mod 2^0 == 1 always exists even for even numbers.

        let a = U256::from(10u64).inv_mod2k(0).unwrap();
        assert_eq!(a, U256::ZERO);

        let a = U256::from(10u64).inv_mod2k_vartime(0).unwrap();
        assert_eq!(a, U256::ZERO);
    }

    /// Inversion modulo an odd 1024-bit modulus, through both `inv_odd_mod` and `inv_mod`.
    #[test]
    fn test_invert_odd() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
        ])
        .to_odd()
        .unwrap();
        let expected = U1024::from_be_hex(concat![
            "B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
            "D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
            "88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA",
            "3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
        ]);

        let res = a.inv_odd_mod(&m).unwrap();
        assert_eq!(res, expected);

        // Even though it is less efficient, it still works
        let res = a.inv_mod(&m).unwrap();
        assert_eq!(res, expected);
    }

    /// `inv_odd_mod` reports none when the value shares a factor with the modulus.
    #[test]
    fn test_invert_odd_no_inverse() {
        // 2^128 - 159, a prime
        let p1 =
            U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff61");
        // 2^128 - 173, a prime
        let p2 =
            U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff53");

        let m = p1.wrapping_mul(&p2).to_odd().unwrap();

        // `m` is a multiple of `p1`, so no inverse exists
        let res = p1.inv_odd_mod(&m);
        assert!(res.is_none().is_true_vartime());
    }

    /// Inversion modulo an even 1024-bit modulus through the general `inv_mod`.
    #[test]
    fn test_invert_even() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
        ]);
        let expected = U1024::from_be_hex(concat![
            "1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357",
            "DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225",
            "FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3",
            "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D",
        ]);

        let res = a.inv_mod(&m).unwrap();
        assert_eq!(res, expected);
    }

    /// Single-limb sanity check: 3 * 9 == 27 == 1 mod 13.
    #[test]
    fn test_invert_small() {
        let a = U64::from(3u64);
        let m = U64::from(13u64).to_odd().unwrap();

        let res = a.inv_odd_mod(&m).unwrap();
        assert_eq!(U64::from(9u64), res);
    }

    /// Single-limb check that a shared factor (7) yields none.
    #[test]
    fn test_no_inverse_small() {
        let a = U64::from(14u64);
        let m = U64::from(49u64).to_odd().unwrap();

        let res = a.inv_odd_mod(&m);
        assert!(res.is_none().is_true_vartime());
    }
}