1use crate::{NonZero, SquareRoot, Uint};
4use subtle::{ConstantTimeEq, CtOption};
5
6impl<const LIMBS: usize> Uint<LIMBS> {
7 pub const fn sqrt(&self) -> Self {
11 let mut x = Self::ONE
20 .overflowing_shl((self.bits() + 1) >> 1)
21 .expect("shift within range"); let mut i = 0;
25 let mut x_prev = x; while i < Self::LOG2_BITS + 2 {
29 x_prev = x;
30
31 let x_nonzero = x.is_nonzero();
33 let (q, _) = self.div_rem(&NonZero(Self::select(&Self::ONE, &x, x_nonzero)));
34 x = Self::select(&Self::ZERO, &x.wrapping_add(&q).shr1(), x_nonzero);
35 i += 1;
36 }
37
38 Self::select(&x_prev, &x, Uint::gt(&x_prev, &x))
42 }
43
44 pub const fn sqrt_vartime(&self) -> Self {
48 if self.cmp_vartime(&Self::ZERO).is_eq() {
51 return Self::ZERO;
52 }
53
54 let mut x = Self::ONE
57 .overflowing_shl((self.bits() + 1) >> 1)
58 .expect("shift within range"); while !x.cmp_vartime(&Self::ZERO).is_eq() {
62 let q = self.wrapping_div_vartime(&x.to_nz().expect("ensured non-zero"));
64 let t = x.wrapping_add(&q);
65 let next_x = t.shr1();
66
67 if !x.cmp_vartime(&next_x).is_gt() {
71 break;
72 }
73
74 x = next_x;
75 }
76
77 x
78 }
79
80 pub const fn wrapping_sqrt(&self) -> Self {
84 self.sqrt()
85 }
86
87 pub const fn wrapping_sqrt_vartime(&self) -> Self {
91 self.sqrt_vartime()
92 }
93
94 pub fn checked_sqrt(&self) -> CtOption<Self> {
97 let r = self.sqrt();
98 let s = r.wrapping_mul(&r);
99 CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
100 }
101
102 pub fn checked_sqrt_vartime(&self) -> CtOption<Self> {
105 let r = self.sqrt_vartime();
106 let s = r.wrapping_mul(&r);
107 CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
108 }
109}
110
111impl<const LIMBS: usize> SquareRoot for Uint<LIMBS> {
112 fn sqrt(&self) -> Self {
113 self.sqrt()
114 }
115
116 fn sqrt_vartime(&self) -> Self {
117 self.sqrt_vartime()
118 }
119}
120
#[cfg(test)]
mod tests {
    use crate::{Limb, U192, U256};

    #[cfg(feature = "rand")]
    use {
        crate::{CheckedMul, Random, U512},
        rand_chacha::ChaChaRng,
        rand_core::{RngCore, SeedableRng},
    };

    /// ⌊√`U256::MAX`⌋ = 2^128 - 1: the lower half of the limbs set to all ones.
    fn sqrt_of_u256_max() -> U256 {
        let mut half = U256::ZERO;
        let lower = half.limbs.len() / 2;
        half.limbs[..lower].fill(Limb::MAX);
        half
    }

    #[test]
    fn edge() {
        assert_eq!(U256::ZERO.sqrt(), U256::ZERO);
        assert_eq!(U256::ONE.sqrt(), U256::ONE);
        assert_eq!(U256::MAX.sqrt(), sqrt_of_u256_max());

        // 0 and 1 are perfect squares; MAX = (2^128 - 1)(2^128 + 1) is not.
        assert_eq!(U256::ZERO.checked_sqrt().is_some().unwrap_u8(), 1);
        assert_eq!(U256::ONE.checked_sqrt().is_some().unwrap_u8(), 1);
        assert_eq!(U256::MAX.checked_sqrt().is_some().unwrap_u8(), 0);

        assert_eq!(
            U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").sqrt(),
            U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
        );
        assert_eq!(
            U192::from_be_hex("055fa39422bd9f281762946e056535badbf8a6864d45fa3d").sqrt_vartime(),
            U192::from_be_hex("0000000000000000000000002516f0832a538b2d98869e21")
        );

        assert_eq!(
            U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
                .sqrt(),
            U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
        );
        assert_eq!(
            U256::from_be_hex("4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597")
                .sqrt_vartime(),
            U256::from_be_hex("000000000000000000000000000000008b3956339e8315cff66eb6107b610075")
        );
    }

    #[test]
    fn edge_vartime() {
        assert_eq!(U256::ZERO.sqrt_vartime(), U256::ZERO);
        assert_eq!(U256::ONE.sqrt_vartime(), U256::ONE);
        assert_eq!(U256::MAX.sqrt_vartime(), sqrt_of_u256_max());

        assert_eq!(U256::ZERO.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
        assert_eq!(U256::ONE.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
        assert_eq!(U256::MAX.checked_sqrt_vartime().is_some().unwrap_u8(), 0);
    }

    #[test]
    fn wrapping_matches_plain() {
        for v in [0u8, 1, 2, 15, 16, 17, 255] {
            let n = U256::from(v);
            assert_eq!(n.wrapping_sqrt(), n.sqrt());
            assert_eq!(n.wrapping_sqrt_vartime(), n.sqrt_vartime());
        }
        assert_eq!(U256::MAX.wrapping_sqrt(), sqrt_of_u256_max());
        assert_eq!(U256::MAX.wrapping_sqrt_vartime(), sqrt_of_u256_max());
    }

    #[test]
    fn simple() {
        let tests = [
            (4u8, 2u8),
            (9, 3),
            (16, 4),
            (25, 5),
            (36, 6),
            (49, 7),
            (64, 8),
            (81, 9),
            (100, 10),
            (121, 11),
            (144, 12),
            (169, 13),
        ];
        for &(a, e) in &tests {
            let l = U256::from(a);
            let r = U256::from(e);
            assert_eq!(l.sqrt(), r);
            assert_eq!(l.sqrt_vartime(), r);
            assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8);
            assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8);
        }
    }

    #[test]
    fn nonsquares() {
        assert_eq!(U256::from(2u8).sqrt(), U256::from(1u8));
        assert_eq!(U256::from(2u8).checked_sqrt().is_some().unwrap_u8(), 0);
        assert_eq!(U256::from(3u8).sqrt(), U256::from(1u8));
        assert_eq!(U256::from(3u8).checked_sqrt().is_some().unwrap_u8(), 0);
        assert_eq!(U256::from(5u8).sqrt(), U256::from(2u8));
        assert_eq!(U256::from(6u8).sqrt(), U256::from(2u8));
        assert_eq!(U256::from(7u8).sqrt(), U256::from(2u8));
        assert_eq!(U256::from(8u8).sqrt(), U256::from(2u8));
        assert_eq!(U256::from(10u8).sqrt(), U256::from(3u8));
    }

    #[test]
    fn nonsquares_vartime() {
        assert_eq!(U256::from(2u8).sqrt_vartime(), U256::from(1u8));
        assert_eq!(
            U256::from(2u8).checked_sqrt_vartime().is_some().unwrap_u8(),
            0
        );
        assert_eq!(U256::from(3u8).sqrt_vartime(), U256::from(1u8));
        assert_eq!(
            U256::from(3u8).checked_sqrt_vartime().is_some().unwrap_u8(),
            0
        );
        assert_eq!(U256::from(5u8).sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(6u8).sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(7u8).sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(8u8).sqrt_vartime(), U256::from(2u8));
        assert_eq!(U256::from(10u8).sqrt_vartime(), U256::from(3u8));
    }

    #[cfg(feature = "rand")]
    #[test]
    fn fuzz() {
        let mut rng = ChaChaRng::from_seed([7u8; 32]);
        for _ in 0..50 {
            let t = u64::from(rng.next_u32());
            let s = U256::from(t);
            let s2 = s.checked_mul(&s).unwrap();
            assert_eq!(s2.sqrt(), s);
            assert_eq!(s2.sqrt_vartime(), s);
            assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1);
            assert_eq!(s2.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
        }

        for _ in 0..50 {
            let s = U256::random(&mut rng);
            let mut s2 = U512::ZERO;
            s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
            let square = s.square();
            assert_eq!(square.sqrt(), s2);
            assert_eq!(square.sqrt_vartime(), s2);
            assert_eq!(square.checked_sqrt().is_some().unwrap_u8(), 1);
            assert_eq!(square.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
        }
    }
}