1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
use super::UInt;
use crate::{Limb, LimbUInt};
use subtle::{ConstantTimeEq, CtOption};
impl<const LIMBS: usize> UInt<LIMBS> {
    /// Computes ⌊√`self`⌋, the integer square root of `self`.
    ///
    /// Uses Newton's iteration for integer square roots (Brent & Zimmermann,
    /// *Modern Computer Arithmetic*, Algorithm 1.13 "SqrtInt"):
    /// 1. Start from an upper bound on the root.
    /// 2. Repeatedly apply `x ← ⌊(x + ⌊self / x⌋) / 2⌋`.
    /// 3. Stop once the estimate no longer decreases.
    ///
    /// Returns [`Self::ZERO`] when `self` is zero. Callers that need to know
    /// whether `self` is a perfect square can use [`Self::checked_sqrt`].
    ///
    /// # Timing
    ///
    /// This function is **not** constant-time, even though it uses `ct_*`
    /// helpers internally. The number of loop iterations and the
    /// `*_vartime` shifts depend on the value of `self`.
    pub const fn sqrt(&self) -> Self {
        // Initial upper bound on the root: self < 2^bits, so
        // √self < 2^⌈bits/2⌉ = cap.
        let max_bits = (self.bits() + 1) >> 1;
        let cap = Self::ONE.shl_vartime(max_bits);
        let mut guess = cap;
        let mut xn = self.sqrt_newton_step(&guess);

        // Walk-up phase: runs only if a step *increased* the estimate.
        // Because cap² > self, the first step from `cap` always lands strictly
        // below `cap`, so this loop does not run for any input. It is kept as
        // a defensive guard that saturates the estimate at `cap`.
        // NOTE(review): removable together with the file-level `Limb` import.
        while guess.ct_cmp(&xn) == -1 {
            // `ct_cmp(..) - 1` is zero only when xn.bits() > max_bits, so `le`
            // is non-zero iff xn.bits() <= max_bits (assuming `Limb::ct_cmp`
            // returns -1/0/1 like `UInt::ct_cmp`).
            let res = Limb::ct_cmp(Limb(xn.bits() as LimbUInt), Limb(max_bits as LimbUInt)) - 1;
            let le = Limb::is_nonzero(Limb(res as LimbUInt));
            // Keep `xn` if it fits within `max_bits` bits; otherwise fall back to `cap`.
            guess = Self::ct_select(cap, xn, le);
            xn = self.sqrt_newton_step(&guess);
        }

        // Descent phase: starting from above, the iteration decreases until it
        // reaches ⌊√self⌋. Stop as soon as it no longer decreases.
        // The non-zero check keeps `guess` from ever becoming zero, which would
        // make the next step divide by zero. This matters e.g. for self == 0.
        while guess.ct_cmp(&xn) == 1 && xn.ct_is_nonzero() == LimbUInt::MAX {
            guess = xn;
            xn = self.sqrt_newton_step(&guess);
        }

        // For self == 0 the loops above leave guess == 1; force the result to 0.
        Self::ct_select(Self::ZERO, guess, self.ct_is_nonzero())
    }

    /// Performs one Newton step for the integer square root of `self`.
    ///
    /// Returns ⌊(`guess` + ⌊`self` / `guess`⌋) / 2⌋. `guess` must be non-zero
    /// because it is used as a divisor.
    const fn sqrt_newton_step(&self, guess: &Self) -> Self {
        let q = self.wrapping_div(guess);
        let t = guess.wrapping_add(&q);
        t.shr_vartime(1)
    }

    /// Wrapping square root; identical to [`Self::sqrt`].
    ///
    /// The integer square root of a value always fits in the same type, so
    /// nothing can wrap. Provided for symmetry with the other `wrapping_*`
    /// operations.
    pub const fn wrapping_sqrt(&self) -> Self {
        self.sqrt()
    }

    /// Computes ⌊√`self`⌋ and returns it wrapped in a [`CtOption`].
    ///
    /// The option is `Some` only when `self` is a perfect square, i.e. when
    /// the computed root squared equals `self`.
    pub fn checked_sqrt(&self) -> CtOption<Self> {
        let r = self.sqrt();
        // r ≤ √self, so r·r ≤ self and the multiply cannot actually wrap.
        let s = r.wrapping_mul(&r);
        CtOption::new(r, self.ct_eq(&s))
    }
}
#[cfg(test)]
mod tests {
    use crate::{Limb, U256};

    #[cfg(feature = "rand")]
    use {
        crate::{CheckedMul, Random, U512},
        rand_chacha::ChaChaRng,
        rand_core::{RngCore, SeedableRng},
    };

    #[test]
    fn edge() {
        assert_eq!(U256::ZERO.sqrt(), U256::ZERO);
        assert_eq!(U256::ONE.sqrt(), U256::ONE);

        // ⌊√(2^256 − 1)⌋ = 2^128 − 1, i.e. the low half of the limbs all ones.
        let mut half = U256::ZERO;
        let low_limbs = half.limbs.len() / 2;
        half.limbs[..low_limbs].fill(Limb::MAX);
        assert_eq!(U256::MAX.sqrt(), half);
    }

    #[test]
    fn checked_and_wrapping_edges() {
        // 0 = 0² and 1 = 1² are perfect squares.
        assert_eq!(U256::ZERO.checked_sqrt().is_some().unwrap_u8(), 1);
        assert_eq!(U256::ONE.checked_sqrt().is_some().unwrap_u8(), 1);
        // 2^256 − 1 lies strictly between (2^128 − 1)² and (2^128)².
        assert_eq!(U256::MAX.checked_sqrt().is_some().unwrap_u8(), 0);

        // The square root never wraps, so wrapping_sqrt must agree with sqrt.
        for v in &[U256::ZERO, U256::ONE, U256::from(99u8), U256::MAX] {
            assert_eq!(v.wrapping_sqrt(), v.sqrt());
        }
    }

    #[test]
    fn simple() {
        let tests = [
            (4u8, 2u8),
            (9, 3),
            (16, 4),
            (25, 5),
            (36, 6),
            (49, 7),
            (64, 8),
            (81, 9),
            (100, 10),
            (121, 11),
            (144, 12),
            (169, 13),
        ];
        for &(square, root) in &tests {
            let square = U256::from(square);
            assert_eq!(square.sqrt(), U256::from(root));
            assert_eq!(square.checked_sqrt().is_some().unwrap_u8(), 1u8);
        }
    }

    #[test]
    fn nonsquares() {
        // (input, ⌊√input⌋); none of the inputs is a perfect square.
        let tests = [(2u8, 1u8), (3, 1), (5, 2), (6, 2), (7, 2), (8, 2), (10, 3)];
        for &(n, root) in &tests {
            let n = U256::from(n);
            assert_eq!(n.sqrt(), U256::from(root));
            assert_eq!(n.checked_sqrt().is_some().unwrap_u8(), 0);
        }
    }

    #[cfg(feature = "rand")]
    #[test]
    fn fuzz() {
        let mut rng = ChaChaRng::from_seed([7u8; 32]);

        // Squares of 32-bit values fit comfortably in a U256.
        for _ in 0..50 {
            let t = u64::from(rng.next_u32());
            let s = U256::from(t);
            let s2 = s.checked_mul(&s).unwrap();
            assert_eq!(s2.sqrt(), s);
            assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1);
        }

        // Full-width roots: square into a U512 and take the root there. The
        // expected result is `s` zero-extended to U512.
        for _ in 0..50 {
            let s = U256::random(&mut rng);
            let mut expected = U512::ZERO;
            expected.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
            assert_eq!(s.square().sqrt(), expected);
        }
    }
}