crypto_bigint/uint/boxed/
sqrt.rs

//! [`BoxedUint`] square root operations.
2
3use subtle::{ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, CtOption};
4
5use crate::{BitOps, BoxedUint, ConstantTimeSelect, NonZero, SquareRoot};
6
7impl BoxedUint {
8    /// Computes √(`self`) in constant time.
9    ///
10    /// Callers can check if `self` is a square by squaring the result
11    pub fn sqrt(&self) -> Self {
12        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13.
13        //
14        // See Hast, "Note on computation of integer square roots"
15        // for the proof of the sufficiency of the bound on iterations.
16        // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
17
18        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
19        // Will not overflow since `b <= BITS`.
20        let (mut x, _overflow) =
21            Self::one_with_precision(self.bits_precision()).overflowing_shl((self.bits() + 1) >> 1); // ≥ √(`self`)
22
23        // Repeat enough times to guarantee result has stabilized.
24        let mut i = 0;
25        let mut x_prev = x.clone(); // keep the previous iteration in case we need to roll back.
26        let mut nz_x = NonZero(x.clone());
27
28        // TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough.
29        while i < self.log2_bits() + 2 {
30            x_prev.limbs.clone_from_slice(&x.limbs);
31
32            // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
33            let x_nonzero = x.is_nonzero();
34            let mut j = 0;
35            while j < nz_x.0.limbs.len() {
36                nz_x.0.limbs[j].conditional_assign(&x.limbs[j], x_nonzero);
37                j += 1;
38            }
39            let (q, _) = self.div_rem(&nz_x);
40            x.conditional_adc_assign(&q, x_nonzero);
41            x.shr1_assign();
42            i += 1;
43        }
44
45        // At this point `x_prev == x_{n}` and `x == x_{n+1}`
46        // where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`.
47        // Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`.
48        Self::ct_select(&x_prev, &x, Self::ct_gt(&x_prev, &x))
49    }
50
51    /// Computes √(`self`)
52    ///
53    /// Callers can check if `self` is a square by squaring the result
54    pub fn sqrt_vartime(&self) -> Self {
55        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
56
57        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
58        // Will not overflow since `b <= BITS`.
59        let (mut x, _overflow) =
60            Self::one_with_precision(self.bits_precision()).overflowing_shl((self.bits() + 1) >> 1); // ≥ √(`self`)
61
62        // Stop right away if `x` is zero to avoid divizion by zero.
63        while !x
64            .cmp_vartime(&Self::zero_with_precision(self.bits_precision()))
65            .is_eq()
66        {
67            // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
68            let q =
69                self.wrapping_div_vartime(&NonZero::<Self>::new(x.clone()).expect("Division by 0"));
70            let t = x.wrapping_add(&q);
71            let next_x = t.shr1();
72
73            // If `next_x` is the same as `x` or greater, we reached convergence
74            // (`x` is guaranteed to either go down or oscillate between
75            // `sqrt(self)` and `sqrt(self) + 1`)
76            if !x.cmp_vartime(&next_x).is_gt() {
77                break;
78            }
79
80            x = next_x;
81        }
82
83        if self.is_nonzero().into() {
84            x
85        } else {
86            Self::zero_with_precision(self.bits_precision())
87        }
88    }
89
90    /// Wrapped sqrt is just normal √(`self`)
91    /// There’s no way wrapping could ever happen.
92    /// This function exists so that all operations are accounted for in the wrapping operations.
93    pub fn wrapping_sqrt(&self) -> Self {
94        self.sqrt()
95    }
96
97    /// Wrapped sqrt is just normal √(`self`)
98    /// There’s no way wrapping could ever happen.
99    /// This function exists so that all operations are accounted for in the wrapping operations.
100    pub fn wrapping_sqrt_vartime(&self) -> Self {
101        self.sqrt_vartime()
102    }
103
104    /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
105    /// only if the √(`self`)² == self
106    pub fn checked_sqrt(&self) -> CtOption<Self> {
107        let r = self.sqrt();
108        let s = r.wrapping_mul(&r);
109        CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
110    }
111
112    /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
113    /// only if the √(`self`)² == self
114    pub fn checked_sqrt_vartime(&self) -> CtOption<Self> {
115        let r = self.sqrt_vartime();
116        let s = r.wrapping_mul(&r);
117        CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
118    }
119}
120
121impl SquareRoot for BoxedUint {
122    fn sqrt(&self) -> Self {
123        self.sqrt()
124    }
125
126    fn sqrt_vartime(&self) -> Self {
127        self.sqrt_vartime()
128    }
129}
130
#[cfg(test)]
mod tests {
    use crate::{BoxedUint, Limb};

    #[cfg(feature = "rand")]
    use {
        crate::{CheckedMul, RandomBits},
        rand_chacha::ChaChaRng,
        rand_core::{RngCore, SeedableRng},
    };

    #[test]
    fn edge() {
        assert_eq!(
            BoxedUint::zero_with_precision(256).sqrt(),
            BoxedUint::zero_with_precision(256)
        );
        assert_eq!(
            BoxedUint::one_with_precision(256).sqrt(),
            BoxedUint::one_with_precision(256)
        );
        let mut half = BoxedUint::zero_with_precision(256);
        for i in 0..half.limbs.len() / 2 {
            half.limbs[i] = Limb::MAX;
        }
        let u256_max = !BoxedUint::zero_with_precision(256);
        assert_eq!(u256_max.sqrt(), half);

        // Test edge cases that use up the maximum number of iterations.

        // `x = (r + 1)^2 - 583`, where `r` is the expected square root.
        assert_eq!(
            BoxedUint::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d", 192)
                .unwrap()
                .sqrt(),
            BoxedUint::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21", 192)
                .unwrap(),
        );
        assert_eq!(
            BoxedUint::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d", 192)
                .unwrap()
                .sqrt_vartime(),
            BoxedUint::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21", 192)
                .unwrap()
        );

        // `x = (r + 1)^2 - 205`, where `r` is the expected square root.
        assert_eq!(
            BoxedUint::from_be_hex(
                "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597",
                256
            )
            .unwrap()
            .sqrt(),
            BoxedUint::from_be_hex(
                "000000000000000000000000000000008b3956339e8315cff66eb6107b610075",
                256
            )
            .unwrap()
        );
        assert_eq!(
            BoxedUint::from_be_hex(
                "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597",
                256
            )
            .unwrap()
            .sqrt_vartime(),
            BoxedUint::from_be_hex(
                "000000000000000000000000000000008b3956339e8315cff66eb6107b610075",
                256
            )
            .unwrap()
        );
    }

    #[test]
    fn edge_vartime() {
        assert_eq!(
            BoxedUint::zero_with_precision(256).sqrt_vartime(),
            BoxedUint::zero_with_precision(256)
        );
        // The zero result keeps the precision (limb count) of the input.
        assert_eq!(
            BoxedUint::zero_with_precision(256).sqrt_vartime().limbs.len(),
            BoxedUint::zero_with_precision(256).limbs.len()
        );
        assert_eq!(
            BoxedUint::one_with_precision(256).sqrt_vartime(),
            BoxedUint::one_with_precision(256)
        );
        let mut half = BoxedUint::zero_with_precision(256);
        for i in 0..half.limbs.len() / 2 {
            half.limbs[i] = Limb::MAX;
        }
        let u256_max = !BoxedUint::zero_with_precision(256);
        assert_eq!(u256_max.sqrt_vartime(), half);
    }

    #[test]
    fn simple() {
        let tests = [
            (4u8, 2u8),
            (9, 3),
            (16, 4),
            (25, 5),
            (36, 6),
            (49, 7),
            (64, 8),
            (81, 9),
            (100, 10),
            (121, 11),
            (144, 12),
            (169, 13),
        ];
        for (a, e) in &tests {
            let l = BoxedUint::from(*a);
            let r = BoxedUint::from(*e);
            assert_eq!(l.sqrt(), r);
            assert_eq!(l.sqrt_vartime(), r);
            assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8);
            assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8);
        }
    }

    #[test]
    fn nonsquares() {
        assert_eq!(BoxedUint::from(2u8).sqrt(), BoxedUint::from(1u8));
        assert_eq!(BoxedUint::from(2u8).checked_sqrt().is_some().unwrap_u8(), 0);
        assert_eq!(BoxedUint::from(3u8).sqrt(), BoxedUint::from(1u8));
        assert_eq!(BoxedUint::from(3u8).checked_sqrt().is_some().unwrap_u8(), 0);
        assert_eq!(BoxedUint::from(5u8).sqrt(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(6u8).sqrt(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(7u8).sqrt(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(8u8).sqrt(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(10u8).sqrt(), BoxedUint::from(3u8));
    }

    #[test]
    fn nonsquares_vartime() {
        assert_eq!(BoxedUint::from(2u8).sqrt_vartime(), BoxedUint::from(1u8));
        assert_eq!(
            BoxedUint::from(2u8)
                .checked_sqrt_vartime()
                .is_some()
                .unwrap_u8(),
            0
        );
        assert_eq!(BoxedUint::from(3u8).sqrt_vartime(), BoxedUint::from(1u8));
        assert_eq!(
            BoxedUint::from(3u8)
                .checked_sqrt_vartime()
                .is_some()
                .unwrap_u8(),
            0
        );
        assert_eq!(BoxedUint::from(5u8).sqrt_vartime(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(6u8).sqrt_vartime(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(7u8).sqrt_vartime(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(8u8).sqrt_vartime(), BoxedUint::from(2u8));
        assert_eq!(BoxedUint::from(10u8).sqrt_vartime(), BoxedUint::from(3u8));
    }

    #[cfg(feature = "rand")]
    #[test]
    fn fuzz() {
        const BITS: u32 = 512;

        let mut rng = ChaChaRng::from_seed([7u8; 32]);
        for _ in 0..50 {
            let t = rng.next_u32() as u64;
            let s = BoxedUint::from(t);
            let s2 = s.checked_mul(&s).unwrap();
            assert_eq!(s2.sqrt(), s);
            assert_eq!(s2.sqrt_vartime(), s);
            assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1);
            assert_eq!(s2.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
        }

        for _ in 0..50 {
            let s = BoxedUint::random_bits(&mut rng, BITS);
            let s_squared = s.square();
            // `square` doubles the precision, so zero-extend `s` into a buffer of the
            // same (doubled) precision before comparing it with the root.
            let mut expected = BoxedUint::zero_with_precision(2 * BITS);
            expected.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
            assert_eq!(s_squared.sqrt(), expected);
            assert_eq!(s_squared.sqrt_vartime(), expected);
        }
    }
}