crypto_bigint/uint/boxed/
sqrt.rs1use subtle::{ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, CtOption};
4
5use crate::{BitOps, BoxedUint, ConstantTimeSelect, NonZero, SquareRoot};
6
7impl BoxedUint {
8 pub fn sqrt(&self) -> Self {
12 let (mut x, _overflow) =
21 Self::one_with_precision(self.bits_precision()).overflowing_shl((self.bits() + 1) >> 1); let mut i = 0;
25 let mut x_prev = x.clone(); let mut nz_x = NonZero(x.clone());
27
28 while i < self.log2_bits() + 2 {
30 x_prev.limbs.clone_from_slice(&x.limbs);
31
32 let x_nonzero = x.is_nonzero();
34 let mut j = 0;
35 while j < nz_x.0.limbs.len() {
36 nz_x.0.limbs[j].conditional_assign(&x.limbs[j], x_nonzero);
37 j += 1;
38 }
39 let (q, _) = self.div_rem(&nz_x);
40 x.conditional_adc_assign(&q, x_nonzero);
41 x.shr1_assign();
42 i += 1;
43 }
44
45 Self::ct_select(&x_prev, &x, Self::ct_gt(&x_prev, &x))
49 }
50
51 pub fn sqrt_vartime(&self) -> Self {
55 let (mut x, _overflow) =
60 Self::one_with_precision(self.bits_precision()).overflowing_shl((self.bits() + 1) >> 1); while !x
64 .cmp_vartime(&Self::zero_with_precision(self.bits_precision()))
65 .is_eq()
66 {
67 let q =
69 self.wrapping_div_vartime(&NonZero::<Self>::new(x.clone()).expect("Division by 0"));
70 let t = x.wrapping_add(&q);
71 let next_x = t.shr1();
72
73 if !x.cmp_vartime(&next_x).is_gt() {
77 break;
78 }
79
80 x = next_x;
81 }
82
83 if self.is_nonzero().into() {
84 x
85 } else {
86 Self::zero_with_precision(self.bits_precision())
87 }
88 }
89
90 pub fn wrapping_sqrt(&self) -> Self {
94 self.sqrt()
95 }
96
97 pub fn wrapping_sqrt_vartime(&self) -> Self {
101 self.sqrt_vartime()
102 }
103
104 pub fn checked_sqrt(&self) -> CtOption<Self> {
107 let r = self.sqrt();
108 let s = r.wrapping_mul(&r);
109 CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
110 }
111
112 pub fn checked_sqrt_vartime(&self) -> CtOption<Self> {
115 let r = self.sqrt_vartime();
116 let s = r.wrapping_mul(&r);
117 CtOption::new(r, ConstantTimeEq::ct_eq(self, &s))
118 }
119}
120
121impl SquareRoot for BoxedUint {
122 fn sqrt(&self) -> Self {
123 self.sqrt()
124 }
125
126 fn sqrt_vartime(&self) -> Self {
127 self.sqrt_vartime()
128 }
129}
130
#[cfg(test)]
mod tests {
    use crate::{BoxedUint, Limb};

    #[cfg(feature = "rand")]
    use {
        crate::{CheckedMul, RandomBits},
        rand_chacha::ChaChaRng,
        rand_core::{RngCore, SeedableRng},
    };

    /// `floor(√(2^256 - 1))` at 256-bit precision: the low half of the limbs all set to
    /// `Limb::MAX`, the high half zero.
    fn sqrt_of_u256_max() -> BoxedUint {
        let mut expected = BoxedUint::zero_with_precision(256);
        let low_half = expected.limbs.len() / 2;
        expected.limbs[..low_half].fill(Limb::MAX);
        expected
    }

    #[test]
    fn edge() {
        let zero = BoxedUint::zero_with_precision(256);
        let one = BoxedUint::one_with_precision(256);
        assert_eq!(zero.sqrt(), zero);
        assert_eq!(one.sqrt(), one);

        let u256_max = !BoxedUint::zero_with_precision(256);
        assert_eq!(u256_max.sqrt(), sqrt_of_u256_max());

        // (square, expected floor root, precision in bits)
        let cases = [
            (
                "055fa39422bd9f281762946e056535badbf8a6864d45fa3d",
                "0000000000000000000000002516f0832a538b2d98869e21",
                192,
            ),
            (
                "4bb750738e25a8f82940737d94a48a91f8cd918a3679ff90c1a631f2bd6c3597",
                "000000000000000000000000000000008b3956339e8315cff66eb6107b610075",
                256,
            ),
        ];
        for (square_hex, root_hex, precision) in cases {
            let square = BoxedUint::from_be_hex(square_hex, precision).unwrap();
            let root = BoxedUint::from_be_hex(root_hex, precision).unwrap();
            assert_eq!(square.sqrt(), root);
            assert_eq!(square.sqrt_vartime(), root);
        }
    }

    #[test]
    fn edge_vartime() {
        let zero = BoxedUint::zero_with_precision(256);
        let one = BoxedUint::one_with_precision(256);
        assert_eq!(zero.sqrt_vartime(), zero);
        assert_eq!(one.sqrt_vartime(), one);

        let u256_max = !BoxedUint::zero_with_precision(256);
        assert_eq!(u256_max.sqrt_vartime(), sqrt_of_u256_max());
    }

    #[test]
    fn simple() {
        // Perfect squares 4, 9, ..., 169 and their roots 2..=13.
        for root in 2u8..=13 {
            let square = BoxedUint::from(root * root);
            let expected = BoxedUint::from(root);
            assert_eq!(square.sqrt(), expected);
            assert_eq!(square.sqrt_vartime(), expected);
            assert_eq!(square.checked_sqrt().is_some().unwrap_u8(), 1u8);
            assert_eq!(square.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8);
        }
    }

    #[test]
    fn nonsquares() {
        let floor_roots = [(2u8, 1u8), (3, 1), (5, 2), (6, 2), (7, 2), (8, 2), (10, 3)];
        for (value, floor_root) in floor_roots {
            assert_eq!(BoxedUint::from(value).sqrt(), BoxedUint::from(floor_root));
        }
        for value in [2u8, 3] {
            assert_eq!(BoxedUint::from(value).checked_sqrt().is_some().unwrap_u8(), 0);
        }
    }

    #[test]
    fn nonsquares_vartime() {
        let floor_roots = [(2u8, 1u8), (3, 1), (5, 2), (6, 2), (7, 2), (8, 2), (10, 3)];
        for (value, floor_root) in floor_roots {
            assert_eq!(
                BoxedUint::from(value).sqrt_vartime(),
                BoxedUint::from(floor_root)
            );
        }
        for value in [2u8, 3] {
            assert_eq!(
                BoxedUint::from(value)
                    .checked_sqrt_vartime()
                    .is_some()
                    .unwrap_u8(),
                0
            );
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn fuzz() {
        let mut rng = ChaChaRng::from_seed([7u8; 32]);

        // Squares of random 32-bit values.
        for _ in 0..50 {
            let root = BoxedUint::from(u64::from(rng.next_u32()));
            let square = root.checked_mul(&root).unwrap();
            assert_eq!(square.sqrt(), root);
            assert_eq!(square.sqrt_vartime(), root);
            assert_eq!(square.checked_sqrt().is_some().unwrap_u8(), 1);
            assert_eq!(square.checked_sqrt_vartime().is_some().unwrap_u8(), 1);
        }

        // Squares of random 512-bit values.
        for _ in 0..50 {
            let value = BoxedUint::random_bits(&mut rng, 512);
            let mut expected = BoxedUint::zero_with_precision(512);
            expected.limbs[..value.limbs.len()].copy_from_slice(&value.limbs);

            let square = value.square();
            assert_eq!(square.sqrt(), expected);
            assert_eq!(square.sqrt_vartime(), expected);
        }
    }
}