crypto_bigint/uint/
sqrt.rs

1//! [`Uint`] square root operations.
2
3use crate::{NonZero, SquareRoot, Uint};
4use subtle::{ConstantTimeEq, CtOption};
5
6impl<const LIMBS: usize> Uint<LIMBS> {
7    /// Computes √(`self`) in constant time.
8    ///
9    /// Callers can check if `self` is a square by squaring the result
10    pub const fn sqrt(&self) -> Self {
11        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13.
12        //
13        // See Hast, "Note on computation of integer square roots"
14        // for the proof of the sufficiency of the bound on iterations.
15        // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf
16
17        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
18        // Will not overflow since `b <= BITS`.
19        let mut x = Self::ONE
20            .overflowing_shl((self.bits() + 1) >> 1)
21            .expect("shift within range"); // ≥ √(`self`)
22
23        // Repeat enough times to guarantee result has stabilized.
24        let mut i = 0;
25        let mut x_prev = x; // keep the previous iteration in case we need to roll back.
26
27        // TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough.
28        while i < Self::LOG2_BITS + 2 {
29            x_prev = x;
30
31            // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
32            let x_nonzero = x.is_nonzero();
33            let (q, _) = self.div_rem(&NonZero(Self::select(&Self::ONE, &x, x_nonzero)));
34            x = Self::select(&Self::ZERO, &x.wrapping_add(&q).shr1(), x_nonzero);
35            i += 1;
36        }
37
38        // At this point `x_prev == x_{n}` and `x == x_{n+1}`
39        // where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`.
40        // Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`.
41        Self::select(&x_prev, &x, Uint::gt(&x_prev, &x))
42    }
43
44    /// Computes √(`self`)
45    ///
46    /// Callers can check if `self` is a square by squaring the result
47    pub const fn sqrt_vartime(&self) -> Self {
48        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
49
50        if self.cmp_vartime(&Self::ZERO).is_eq() {
51            return Self::ZERO;
52        }
53
54        // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
55        // Will not overflow since `b <= BITS`.
56        let mut x = Self::ONE
57            .overflowing_shl((self.bits() + 1) >> 1)
58            .expect("shift within range"); // ≥ √(`self`)
59
60        // Stop right away if `x` is zero to avoid divizion by zero.
61        while !x.cmp_vartime(&Self::ZERO).is_eq() {
62            // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
63            let q = self.wrapping_div_vartime(&x.to_nz().expect("ensured non-zero"));
64            let t = x.wrapping_add(&q);
65            let next_x = t.shr1();
66
67            // If `next_x` is the same as `x` or greater, we reached convergence
68            // (`x` is guaranteed to either go down or oscillate between
69            // `sqrt(self)` and `sqrt(self) + 1`)
70            if !x.cmp_vartime(&next_x).is_gt() {
71                break;
72            }
73
74            x = next_x;
75        }
76
77        x
78    }
79
80    /// Wrapped sqrt is just normal √(`self`)
81    /// There’s no way wrapping could ever happen.
82    /// This function exists so that all operations are accounted for in the wrapping operations.
83    pub const fn wrapping_sqrt(&self) -> Self {
84        self.sqrt()
85    }
86
87    /// Wrapped sqrt is just normal √(`self`)
88    /// There’s no way wrapping could ever happen.
89    /// This function exists so that all operations are accounted for in the wrapping operations.
90    pub const fn wrapping_sqrt_vartime(&self) -> Self {
91        self.sqrt_vartime()
92    }
93
94    /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
95    /// only if the √(`self`)² == self
96    pub fn checked_sqrt(&self) -> CtOption<Self> {
97        let r = self.sqrt();
98        let s = r.wrapping_mul(&r);
99        CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
100    }
101
102    /// Perform checked sqrt, returning a [`CtOption`] which `is_some`
103    /// only if the √(`self`)² == self
104    pub fn checked_sqrt_vartime(&self) -> CtOption<Self> {
105        let r = self.sqrt_vartime();
106        let s = r.wrapping_mul(&r);
107        CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
108    }
109}
110
111impl<const LIMBS: usize> SquareRoot for Uint<LIMBS> {
112    fn sqrt(&self) -> Self {
113        self.sqrt()
114    }
115
116    fn sqrt_vartime(&self) -> Self {
117        self.sqrt_vartime()
118    }
119}
120
#[cfg(test)]
mod tests {
    use crate::{Limb, U192, U256};

    #[cfg(feature = "rand")]
    use {
        crate::{CheckedMul, Random, U512},
        rand_chacha::ChaChaRng,
        rand_core::{RngCore, SeedableRng},
    };

    /// Non-squares paired with their floor square roots.
    const NONSQUARES: [(u8, u8); 7] = [(2, 1), (3, 1), (5, 2), (6, 2), (7, 2), (8, 2), (10, 3)];

    /// `U256` whose lower half of limbs are all ones, i.e. `⌊√(U256::MAX)⌋`.
    fn sqrt_of_max() -> U256 {
        let mut value = U256::ZERO;
        let half_len = value.limbs.len() / 2;
        value.limbs[..half_len].fill(Limb::MAX);
        value
    }

    #[test]
    fn edge() {
        assert_eq!(U256::ZERO.sqrt(), U256::ZERO);
        assert_eq!(U256::ONE.sqrt(), U256::ONE);
        assert_eq!(U256::MAX.sqrt(), sqrt_of_max());

        // Test edge cases that use up the maximum number of iterations.

        // `x = (r + 1)^2 - 583`, where `r` is the expected square root.
        let x192 = U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d");
        let r192 = U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21");
        assert_eq!(x192.sqrt(), r192);
        assert_eq!(x192.sqrt_vartime(), r192);

        // `x = (r + 1)^2 - 205`, where `r` is the expected square root.
        let x256 =
            U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597");
        let r256 =
            U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075");
        assert_eq!(x256.sqrt(), r256);
        assert_eq!(x256.sqrt_vartime(), r256);
    }

    #[test]
    fn edge_vartime() {
        assert_eq!(U256::ZERO.sqrt_vartime(), U256::ZERO);
        assert_eq!(U256::ONE.sqrt_vartime(), U256::ONE);
        assert_eq!(U256::MAX.sqrt_vartime(), sqrt_of_max());
    }

    #[test]
    fn simple() {
        // Perfect squares 4, 9, ..., 169.
        for root in 2u8..=13 {
            let square = U256::from(root * root);
            let expected = U256::from(root);
            assert_eq!(square.sqrt(), expected);
            assert_eq!(square.sqrt_vartime(), expected);
            assert_eq!(square.checked_sqrt().is_some().unwrap_u8(), 1u8);
            assert_eq!(square.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8);
        }
    }

    #[test]
    fn nonsquares() {
        for (n, root) in NONSQUARES {
            assert_eq!(U256::from(n).sqrt(), U256::from(root));
        }
        for n in [2u8, 3] {
            assert_eq!(U256::from(n).checked_sqrt().is_some().unwrap_u8(), 0);
        }
    }

    #[test]
    fn nonsquares_vartime() {
        for (n, root) in NONSQUARES {
            assert_eq!(U256::from(n).sqrt_vartime(), U256::from(root));
        }
        for n in [2u8, 3] {
            assert_eq!(
                U256::from(n).checked_sqrt_vartime().is_some().unwrap_u8(),
                0
            );
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn fuzz() {
        let mut rng = ChaChaRng::from_seed([7u8; 32]);

        // Squares of random 32-bit values.
        for _ in 0..50 {
            let root = U256::from(rng.next_u32() as u64);
            let square = root.checked_mul(&root).unwrap();
            assert_eq!(square.sqrt(), root);
            assert_eq!(square.sqrt_vartime(), root);
            assert_eq!(square.checked_sqrt().is_some().unwrap_u8(), 1);
            assert_eq!(square.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
        }

        // Squares of random full-width values, computed in double width.
        for _ in 0..50 {
            let root = U256::random(&mut rng);
            let mut widened = U512::ZERO;
            widened.limbs[..root.limbs.len()].copy_from_slice(&root.limbs);
            let square = root.square();
            assert_eq!(square.sqrt(), widened);
            assert_eq!(square.sqrt_vartime(), widened);
        }
    }
}