sp_arithmetic/
helpers_128bit.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// Some code is based upon Derek Dreery's IntegerSquareRoot impl, used under license.
5// SPDX-License-Identifier: Apache-2.0
6
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License at
10//
11// 	http://www.apache.org/licenses/LICENSE-2.0
12//
13// Unless required by applicable law or agreed to in writing, software
14// distributed under the License is distributed on an "AS IS" BASIS,
15// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16// See the License for the specific language governing permissions and
17// limitations under the License.
18
19//! Some helper functions to work with 128bit numbers. Note that the functionality provided here is
20//! only sensible to use with 128bit numbers because for smaller sizes, you can always rely on
21//! assumptions of a bigger type (u128) being available, or simply create a per-thing and use the
22//! multiplication implementation provided there.
23
24use crate::{biguint, Rounding};
25use core::cmp::{max, min};
26
/// Greatest common divisor of `a` and `b`, used by the `Rational128` implementation.
///
/// Implemented as an iterative binary GCD. If either operand is zero the other one is returned,
/// so `gcd(0, 0)` is `0`.
pub fn gcd(a: u128, b: u128) -> u128 {
	// With a zero operand the answer is simply the other operand (which may be zero too).
	if a == 0 || b == 0 {
		return a | b
	}

	// Powers of two dividing both operands are part of the result; set them aside.
	let shared_twos = (a | b).trailing_zeros();

	// Any remaining factor of two divides only one operand and cannot be part of the gcd.
	let mut small = a >> a.trailing_zeros();
	let mut large = b >> b.trailing_zeros();

	// Both values are odd here, so their difference is even and can be halved until odd again.
	while small != large {
		let (lo, hi) = (min(small, large), max(small, large));
		// Non-zero because `lo != hi`, hence `trailing_zeros` is below 128.
		let diff = hi - lo;
		small = lo;
		large = diff >> diff.trailing_zeros();
	}

	small << shared_twos
}
41
/// Split a `u128` into its two `u64` halves, returned as `(high, low)`.
pub fn split(a: u128) -> (u64, u64) {
	// The `as u64` casts intentionally truncate to the least significant 64 bits.
	((a >> 64) as u64, a as u64)
}
48
49/// Convert a u128 to a u32 based biguint.
50pub fn to_big_uint(x: u128) -> biguint::BigUint {
51	let (xh, xl) = split(x);
52	let (xhh, xhl) = biguint::split(xh);
53	let (xlh, xll) = biguint::split(xl);
54	let mut n = biguint::BigUint::from_limbs(&[xhh, xhl, xlh, xll]);
55	n.lstrip();
56	n
57}
58
mod double128 {
	// Inspired by: https://medium.com/wicketh/mathemagic-512-bit-division-in-solidity-afa55870a65

	/// Returns the least significant 64 bits of `a`.
	const fn low_64(a: u128) -> u128 {
		a & ((1 << 64) - 1)
	}

	/// Returns the most significant 64 bits of `a`, moved down into the low 64 bits.
	const fn high_64(a: u128) -> u128 {
		a >> 64
	}

	/// Returns `2^128 - a` modulo `2^128` (two's complement negation), so `neg128(0) == 0`.
	const fn neg128(a: u128) -> u128 {
		(!a).wrapping_add(1)
	}

	/// Returns `2^128 / a` for `a >= 2`.
	///
	/// For `a == 1` the true quotient `2^128` does not fit into a `u128` and the result wraps to
	/// `0`. For `a == 0` this panics with a division by zero.
	const fn div128(a: u128) -> u128 {
		(neg128(a) / a).wrapping_add(1)
	}

	/// Returns `2^128 % a`. Panics with a division by zero if `a == 0`.
	const fn mod128(a: u128) -> u128 {
		neg128(a) % a
	}

	/// An unsigned 256-bit integer stored as two 128-bit halves.
	#[derive(Copy, Clone, Eq, PartialEq)]
	pub struct Double128 {
		high: u128,
		low: u128,
	}

	impl Double128 {
		/// Returns the value as a `u128`, or `Err(())` if any of the upper 128 bits is set.
		pub const fn try_into_u128(self) -> Result<u128, ()> {
			match self.high {
				0 => Ok(self.low),
				_ => Err(()),
			}
		}

		/// The value zero.
		pub const fn zero() -> Self {
			Self { high: 0, low: 0 }
		}

		/// Return a `Double128` value representing the `scaled_value << 64`.
		///
		/// This means the lower half of the `high` component will be equal to the upper 64-bits of
		/// `scaled_value` (in the lower positions) and the upper half of the `low` component will
		/// be equal to the lower 64-bits of `scaled_value`.
		pub const fn left_shift_64(scaled_value: u128) -> Self {
			Self { high: scaled_value >> 64, low: scaled_value << 64 }
		}

		/// Construct a value from the lower 128 bits only, with the upper 128 bits being zeroed.
		pub const fn from_low(low: u128) -> Self {
			Self { high: 0, low }
		}

		/// Returns the same value ignoring anything in the high 128-bits.
		pub const fn low_part(self) -> Self {
			Self { high: 0, ..self }
		}

		/// Returns `a * b` as a full 256-bit product.
		pub const fn product_of(a: u128, b: u128) -> Self {
			// Split a and b into hi and lo 64-bit parts
			let (a_low, a_high) = (low_64(a), high_64(a));
			let (b_low, b_high) = (low_64(b), high_64(b));
			// a = a_low + (a_high << 64); b = b_low + (b_high << 64);
			// ergo a*b = (a_low + (a_high << 64)) * (b_low + (b_high << 64))
			//          = a_low * b_low
			//          + (a_low * b_high) << 64
			//          + (a_high * b_low) << 64
			//          + (a_high * b_high) << 128
			// assuming:
			//        f = a_low * b_low
			//        o = a_low * b_high
			//        i = a_high * b_low
			//        l = a_high * b_high
			// then:
			//      a*b = ((o + i) << 64) + f + (l << 128)
			//
			// Every operand is below 2^64, so none of the four partial products can overflow.
			let (f, o, i, l) = (a_low * b_low, a_low * b_high, a_high * b_low, a_high * b_high);
			let fl = Self { high: l, low: f };
			let i = Self::left_shift_64(i);
			let o = Self::left_shift_64(o);
			fl.add(i).add(o)
		}

		/// Returns `self + b`, wrapping around at `2^256`.
		pub const fn add(self, b: Self) -> Self {
			let (low, overflow) = self.low.overflowing_add(b.low);
			// `true as u128` is 1 and `false as u128` is 0.
			let carry = overflow as u128;
			let high = self.high.wrapping_add(b.high).wrapping_add(carry);
			Self { high, low }
		}

		/// Divides `self` by `rhs`, returning `(quotient, remainder)`.
		///
		/// # Panics
		///
		/// Panics with a division by zero if `rhs == 0`.
		pub const fn div(mut self, rhs: u128) -> (Self, u128) {
			// Dividing by one is the identity. It has to be handled up front because
			// `div128(1)` cannot represent `2^128` and wraps to zero.
			if rhs == 1 {
				return (self, 0)
			}

			// (self === a; rhs === b)
			// Calculate a / b
			// = ((a_high << 128) + a_low) / b
			//   let (q, r) = (div128(b), mod128(b)), i.e. 2^128 = q * b + r;
			// = (a_high * (q * b + r) + a_low) / b
			// = (a_high * q * b + a_high * r + a_low) / b
			// = (a_high * r + a_low) / b + a_high * q
			let (q, r) = (div128(rhs), mod128(rhs));

			// x = current result
			// a = next number
			let mut x = Self::zero();
			while self.high != 0 {
				// x += a.high * q
				x = x.add(Self::product_of(self.high, q));
				// a = a.high * r + a.low
				self = Self::product_of(self.high, r).add(self.low_part());
			}

			// What is left fits into 128 bits, so native division finishes the job.
			(x.add(Self::from_low(self.low / rhs)), self.low % rhs)
		}
	}
}
184
185/// Returns `a * b / c` (wrapping to 128 bits) or `None` in the case of
186/// overflow.
187pub const fn multiply_by_rational_with_rounding(
188	a: u128,
189	b: u128,
190	c: u128,
191	r: Rounding,
192) -> Option<u128> {
193	use double128::Double128;
194	if c == 0 {
195		return None
196	}
197	let (result, remainder) = Double128::product_of(a, b).div(c);
198	let mut result: u128 = match result.try_into_u128() {
199		Ok(v) => v,
200		Err(_) => return None,
201	};
202	if match r {
203		Rounding::Up => remainder > 0,
204		// cannot be `(c + 1) / 2` since `c` might be `max_value` and overflow.
205		Rounding::NearestPrefUp => remainder >= c / 2 + c % 2,
206		Rounding::NearestPrefDown => remainder > c / 2,
207		Rounding::Down => false,
208	} {
209		result = match result.checked_add(1) {
210			Some(v) => v,
211			None => return None,
212		};
213	}
214	Some(result)
215}
216
/// Integer square root of `n`, rounded down.
///
/// Uses the digit-by-digit (base 2) method, so it only needs shifts, additions and comparisons.
pub const fn sqrt(mut n: u128) -> u128 {
	// Modified from https://github.com/derekdreery/integer-sqrt-rs (Apache/MIT).
	if n == 0 {
		return 0
	}

	// Start with `bit` as the largest power of 4 that is <= n: take the index of the highest
	// set bit and clear its lowest bit to make it even.
	let top_bit: u32 = u128::BITS - 1 - n.leading_zeros();
	let mut bit = 1u128 << (top_bit & !1);

	// Algorithm based on the implementation in:
	// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)
	let mut root = 0u128;
	while bit != 0 {
		let trial = root + bit;
		root >>= 1;
		if n >= trial {
			n -= trial;
			root += bit;
		}
		bit >>= 2;
	}
	root
}
243
#[cfg(test)]
mod tests {
	use super::*;
	use codec::{Decode, Encode};
	use multiply_by_rational_with_rounding as mulrat;
	use Rounding::*;

	const MAX: u128 = u128::MAX;

	#[test]
	fn rational_multiply_basic_rounding_works() {
		assert_eq!(mulrat(1, 1, 1, Up), Some(1));
		assert_eq!(mulrat(3, 1, 3, Up), Some(1));
		assert_eq!(mulrat(1, 1, 3, Up), Some(1));
		assert_eq!(mulrat(1, 2, 3, Down), Some(0));
		assert_eq!(mulrat(1, 1, 3, NearestPrefDown), Some(0));
		assert_eq!(mulrat(1, 1, 2, NearestPrefDown), Some(0));
		assert_eq!(mulrat(1, 2, 3, NearestPrefDown), Some(1));
		assert_eq!(mulrat(1, 1, 3, NearestPrefUp), Some(0));
		assert_eq!(mulrat(1, 1, 2, NearestPrefUp), Some(1));
		assert_eq!(mulrat(1, 2, 3, NearestPrefUp), Some(1));
	}

	#[test]
	fn rational_multiply_big_number_works() {
		assert_eq!(mulrat(MAX, MAX - 1, MAX, Down), Some(MAX - 1));
		assert_eq!(mulrat(MAX, 1, MAX, Down), Some(1));
		assert_eq!(mulrat(MAX, MAX - 1, MAX, Up), Some(MAX - 1));
		assert_eq!(mulrat(MAX, 1, MAX, Up), Some(1));
		assert_eq!(mulrat(1, MAX - 1, MAX, Down), Some(0));
		assert_eq!(mulrat(1, 1, MAX, Up), Some(1));
		assert_eq!(mulrat(1, MAX / 2, MAX, NearestPrefDown), Some(0));
		assert_eq!(mulrat(1, MAX / 2 + 1, MAX, NearestPrefDown), Some(1));
		assert_eq!(mulrat(1, MAX / 2, MAX, NearestPrefUp), Some(0));
		assert_eq!(mulrat(1, MAX / 2 + 1, MAX, NearestPrefUp), Some(1));
	}

	#[test]
	fn sqrt_works() {
		for i in 0..100_000u32 {
			let a = sqrt(random_u128(i));
			assert_eq!(sqrt(a * a), a);
		}
	}

	/// Deterministic pseudo-random `u128` derived from hashing the encoded `seed`.
	fn random_u128(seed: u32) -> u128 {
		u128::decode(&mut &seed.using_encoded(sp_crypto_hashing::twox_128)[..]).unwrap_or(0)
	}

	/// Independent oracle for `a * b / c`, written without any 256-bit helper.
	///
	/// Processes `b` bit by bit (double-and-add) while maintaining
	/// `a * <bits of b consumed so far> == q * c + r` with `r < c`.
	///
	/// Returns `Some((quotient, remainder))`, or `None` if `c == 0` or the quotient does not fit
	/// into a `u128`. The quotient only ever grows, so an intermediate overflow implies that the
	/// final quotient overflows as well.
	fn reference_div_rem(a: u128, b: u128, c: u128) -> Option<(u128, u128)> {
		if c == 0 {
			return None
		}
		let (a_quot, a_rem) = (a / c, a % c);
		let (mut q, mut r) = (0u128, 0u128);
		for bit in (0..128u32).rev() {
			// Double the accumulated value.
			q = q.checked_mul(2)?;
			let (doubled, wrapped) = r.overflowing_add(r);
			r = doubled;
			// `2r < 2c`, so one subtraction normalizes it; the wrapping subtraction yields the
			// correct value even when `2r` exceeded 128 bits.
			if wrapped || r >= c {
				r = r.wrapping_sub(c);
				q = q.checked_add(1)?;
			}

			// Add `a` if the current bit of `b` is set.
			if (b >> bit) & 1 == 1 {
				q = q.checked_add(a_quot)?;
				let (sum, wrapped) = r.overflowing_add(a_rem);
				r = sum;
				if wrapped || r >= c {
					r = r.wrapping_sub(c);
					q = q.checked_add(1)?;
				}
			}
		}
		Some((q, r))
	}

	#[test]
	fn op_checked_rounded_div_works() {
		for i in 0..100_000u32 {
			let a = random_u128(i);
			let b = random_u128(i + (1 << 30));
			let c = random_u128(i + (1 << 31));
			let reference = reference_div_rem(a, b, c);

			let floored = reference.map(|(q, _)| q);
			assert_eq!(mulrat(a, b, c, Down), floored);

			let nearest =
				reference.and_then(|(q, r)| if r > c / 2 { q.checked_add(1) } else { Some(q) });
			assert_eq!(mulrat(a, b, c, NearestPrefDown), nearest);
		}
	}
}