1use crate::{biguint, Rounding};
25use core::cmp::{max, min};
26
/// Computes the greatest common divisor of `a` and `b` with the binary (Stein) method.
///
/// Only shifts, subtraction and comparisons are used. If either argument is zero the
/// other one is returned, so `gcd(0, 0) == 0`.
pub fn gcd(a: u128, b: u128) -> u128 {
    let (mut x, mut y) = (a, b);
    // Number of factors of two that were stripped from *both* operands so far; they are
    // part of the result and get multiplied back in on exit.
    let mut common_twos = 0u32;

    loop {
        // Terminal cases: equal operands, or one operand exhausted.
        if x == y || x == 0 {
            return y << common_twos;
        }
        if y == 0 {
            return x << common_twos;
        }

        match (x & 1, y & 1) {
            // Both even: 2 divides the gcd.
            (0, 0) => {
                x >>= 1;
                y >>= 1;
                common_twos += 1;
            },
            // Exactly one is even: that factor of two cannot be shared, drop it.
            (0, _) => x >>= 1,
            (_, 0) => y >>= 1,
            // Both odd: gcd(x, y) == gcd(|x - y| / 2, min(x, y)), the difference being even.
            _ => {
                let (small, large) = (min(x, y), max(x, y));
                x = (large - small) >> 1;
                y = small;
            },
        }
    }
}
41
/// Splits a `u128` into two `u64` halves, returned as `(high, low)`.
pub fn split(a: u128) -> (u64, u64) {
    // The `as u64` cast truncates, keeping exactly the 64 least significant bits.
    ((a >> 64) as u64, a as u64)
}
48
49pub fn to_big_uint(x: u128) -> biguint::BigUint {
51 let (xh, xl) = split(x);
52 let (xhh, xhl) = biguint::split(xh);
53 let (xlh, xll) = biguint::split(xl);
54 let mut n = biguint::BigUint::from_limbs(&[xhh, xhl, xlh, xll]);
55 n.lstrip();
56 n
57}
58
mod double128 {
    // Minimal 256-bit unsigned arithmetic, represented as two `u128` halves. It provides
    // just enough (widening multiply, add, divide by a `u128`) to evaluate `a * b / c`
    // without losing the upper bits of the intermediate product.

    /// Mask selecting the 64 least significant bits of a `u128`.
    const LOW_64_MASK: u128 = 0xFFFF_FFFF_FFFF_FFFF;

    /// Returns the least significant 64 bits of `a`.
    const fn low_64(a: u128) -> u128 {
        a & LOW_64_MASK
    }

    /// Returns the most significant 64 bits of `a`, moved down into the low positions.
    const fn high_64(a: u128) -> u128 {
        a >> 64
    }

    /// Returns `2^128 - a` reduced modulo `2^128` (two's complement negation).
    const fn neg128(a: u128) -> u128 {
        a.wrapping_neg()
    }

    /// Returns `2^128 / a`, computed as `(2^128 - a) / a + 1`.
    ///
    /// The result wraps to `0` for `a == 1` (the true value does not fit), and the
    /// division panics for `a == 0`.
    const fn div128(a: u128) -> u128 {
        let one_short = neg128(a) / a;
        one_short.wrapping_add(1)
    }

    /// Returns `2^128 % a`, which equals `(2^128 - a) % a`. Panics for `a == 0`.
    const fn mod128(a: u128) -> u128 {
        neg128(a) % a
    }

    /// A 256-bit unsigned integer with value `high * 2^128 + low`.
    #[derive(Copy, Clone, Eq, PartialEq)]
    pub struct Double128 {
        high: u128,
        low: u128,
    }

    impl Double128 {
        /// Returns the value as a `u128`, or `Err(())` if any of the upper 128 bits is set.
        pub const fn try_into_u128(self) -> Result<u128, ()> {
            if self.high == 0 {
                Ok(self.low)
            } else {
                Err(())
            }
        }

        /// The value `0`.
        pub const fn zero() -> Self {
            Self::from_low(0)
        }

        /// Returns the 256-bit value `scaled_value << 64`.
        ///
        /// The upper 64 bits of `scaled_value` end up in the low positions of `high`, the
        /// lower 64 bits end up in the upper positions of `low`.
        pub const fn left_shift_64(scaled_value: u128) -> Self {
            let high = scaled_value >> 64;
            let low = scaled_value << 64;
            Self { high, low }
        }

        /// Builds a value whose lower 128 bits are `low` and whose upper 128 bits are zero.
        pub const fn from_low(low: u128) -> Self {
            Self { high: 0, low }
        }

        /// Returns `self` with the upper 128 bits cleared.
        pub const fn low_part(self) -> Self {
            Self::from_low(self.low)
        }

        /// Returns the full 256-bit product `a * b`.
        pub const fn product_of(a: u128, b: u128) -> Self {
            let a_low = low_64(a);
            let a_high = high_64(a);
            let b_low = low_64(b);
            let b_high = high_64(b);

            // With a = a_low + (a_high << 64) and b = b_low + (b_high << 64):
            //   a * b = a_low * b_low
            //         + (a_high * b_low + a_low * b_high) << 64
            //         + (a_high * b_high) << 128
            // Every partial product multiplies two values below 2^64 and so fits a `u128`.
            let outer = Self { high: a_high * b_high, low: a_low * b_low };
            let cross_high_low = Self::left_shift_64(a_high * b_low);
            let cross_low_high = Self::left_shift_64(a_low * b_high);
            outer.add(cross_high_low).add(cross_low_high)
        }

        /// Returns `self + b`, wrapping around at `2^256`.
        pub const fn add(self, b: Self) -> Self {
            let (low, carried) = self.low.overflowing_add(b.low);
            // A carry out of the low half contributes exactly one to the high half.
            let high = self.high.wrapping_add(b.high).wrapping_add(carried as u128);
            Self { high, low }
        }

        /// Divides `self` by `rhs`, returning `(quotient, remainder)`.
        ///
        /// Panics (division by zero) when `rhs == 0`.
        pub const fn div(mut self, rhs: u128) -> (Self, u128) {
            // `div128(1)` would wrap to zero, so the trivial divisor is answered directly.
            if rhs == 1 {
                return (self, 0)
            }

            // Write 2^128 = pow_quot * rhs + pow_rem. Then
            //   high * 2^128 + low = (high * pow_quot) * rhs + (high * pow_rem + low),
            // so `high * pow_quot` can be moved into the quotient while the dividend is
            // replaced by `high * pow_rem + low`. Repeat until the dividend fits 128 bits.
            let pow_quot = div128(rhs);
            let pow_rem = mod128(rhs);

            let mut quotient = Self::zero();
            while self.high != 0 {
                let top = self.high;
                quotient = quotient.add(Self::product_of(top, pow_quot));
                self = Self::product_of(top, pow_rem).add(self.low_part());
            }

            // What is left fits a `u128`; finish with a native division.
            let tail = Self::from_low(self.low / rhs);
            (quotient.add(tail), self.low % rhs)
        }
    }
}
184
185pub const fn multiply_by_rational_with_rounding(
188 a: u128,
189 b: u128,
190 c: u128,
191 r: Rounding,
192) -> Option<u128> {
193 use double128::Double128;
194 if c == 0 {
195 return None
196 }
197 let (result, remainder) = Double128::product_of(a, b).div(c);
198 let mut result: u128 = match result.try_into_u128() {
199 Ok(v) => v,
200 Err(_) => return None,
201 };
202 if match r {
203 Rounding::Up => remainder > 0,
204 Rounding::NearestPrefUp => remainder >= c / 2 + c % 2,
206 Rounding::NearestPrefDown => remainder > c / 2,
207 Rounding::Down => false,
208 } {
209 result = match result.checked_add(1) {
210 Some(v) => v,
211 None => return None,
212 };
213 }
214 Some(result)
215}
216
/// Returns the integer square root of `n`, i.e. the largest `r` with `r * r <= n`.
///
/// Uses the digit-by-digit (base 2) method: one result bit is settled per iteration,
/// starting from the most significant one.
pub const fn sqrt(mut n: u128) -> u128 {
    if n == 0 {
        return 0
    }

    // Index of the highest set bit of `n`, rounded down to an even number, so that `bit`
    // starts as the largest power of four not exceeding `n`.
    let top_bit: u32 = 127 - n.leading_zeros();
    let mut bit = 1u128 << (top_bit & !1);

    let mut result = 0u128;
    while bit != 0 {
        let trial = result + bit;
        result >>= 1;
        if n >= trial {
            // The current bit belongs to the root: consume it from `n` and record it.
            n -= trial;
            result += bit;
        }
        bit >>= 2;
    }
    result
}
243
#[cfg(test)]
mod tests {
    use super::*;
    use codec::{Decode, Encode};
    use multiply_by_rational_with_rounding as mulrat;
    use Rounding::*;

    /// Largest `u128`, used to exercise the 256-bit intermediate product.
    const MAX: u128 = u128::MAX;

    /// Small operands: every rounding mode, including exact ties (`1 / 2`).
    #[test]
    fn rational_multiply_basic_rounding_works() {
        assert_eq!(mulrat(1, 1, 1, Up), Some(1));
        assert_eq!(mulrat(3, 1, 3, Up), Some(1));
        assert_eq!(mulrat(1, 1, 3, Up), Some(1));
        assert_eq!(mulrat(1, 2, 3, Down), Some(0));
        assert_eq!(mulrat(1, 1, 3, NearestPrefDown), Some(0));
        assert_eq!(mulrat(1, 1, 2, NearestPrefDown), Some(0));
        assert_eq!(mulrat(1, 2, 3, NearestPrefDown), Some(1));
        assert_eq!(mulrat(1, 1, 3, NearestPrefUp), Some(0));
        assert_eq!(mulrat(1, 1, 2, NearestPrefUp), Some(1));
        assert_eq!(mulrat(1, 2, 3, NearestPrefUp), Some(1));
    }

    /// Operands at or near `u128::MAX`, where `a * b` needs more than 128 bits.
    #[test]
    fn rational_multiply_big_number_works() {
        assert_eq!(mulrat(MAX, MAX - 1, MAX, Down), Some(MAX - 1));
        assert_eq!(mulrat(MAX, 1, MAX, Down), Some(1));
        assert_eq!(mulrat(MAX, MAX - 1, MAX, Up), Some(MAX - 1));
        assert_eq!(mulrat(MAX, 1, MAX, Up), Some(1));
        assert_eq!(mulrat(1, MAX - 1, MAX, Down), Some(0));
        assert_eq!(mulrat(1, 1, MAX, Up), Some(1));
        assert_eq!(mulrat(1, MAX / 2, MAX, NearestPrefDown), Some(0));
        assert_eq!(mulrat(1, MAX / 2 + 1, MAX, NearestPrefDown), Some(1));
        assert_eq!(mulrat(1, MAX / 2, MAX, NearestPrefUp), Some(0));
        assert_eq!(mulrat(1, MAX / 2 + 1, MAX, NearestPrefUp), Some(1));
    }

    /// `sqrt` must be the exact inverse of squaring on perfect squares.
    #[test]
    fn sqrt_works() {
        for i in 0..100_000u32 {
            // `a` is itself a square root of a `u128`, so `a * a` cannot overflow.
            let a = sqrt(random_u128(i));
            assert_eq!(sqrt(a * a), a);
        }
    }

    /// Derives a deterministic pseudo-random `u128` from `seed` by decoding the
    /// `twox_128` hash of the encoded seed; yields `0` should decoding fail.
    fn random_u128(seed: u32) -> u128 {
        u128::decode(&mut &seed.using_encoded(sp_crypto_hashing::twox_128)[..]).unwrap_or(0)
    }

    /// Runs the rational multiplication over pseudo-random operands.
    ///
    /// NOTE(review): `mulrat` is only an alias of `multiply_by_rational_with_rounding`, so
    /// `x` and `y` come from the same function and the comparison cannot fail. As written
    /// this only checks that the call does not panic; presumably it once compared against
    /// a separate implementation (confirm, and consider an independent oracle).
    #[test]
    fn op_checked_rounded_div_works() {
        for i in 0..100_000u32 {
            let a = random_u128(i);
            let b = random_u128(i + (1 << 30));
            let c = random_u128(i + (1 << 31));
            let x = mulrat(a, b, c, NearestPrefDown);
            let y = multiply_by_rational_with_rounding(a, b, c, Rounding::NearestPrefDown);
            assert_eq!(x.is_some(), y.is_some());
            let x = x.unwrap_or(0);
            let y = y.unwrap_or(0);
            let d = x.max(y) - x.min(y);
            assert_eq!(d, 0);
        }
    }
}