sp_arithmetic/
rational.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18use crate::{biguint::BigUint, helpers_128bit, Rounding};
19use core::cmp::Ordering;
20use num_traits::{Bounded, One, Zero};
21
22/// A wrapper for any rational number with infinitely large numerator and denominator.
23///
24/// This type exists to facilitate `cmp` operation
25/// on values like `a/b < c/d` where `a, b, c, d` are all `BigUint`.
26#[derive(Clone, Default, Eq)]
27pub struct RationalInfinite(BigUint, BigUint);
28
29impl RationalInfinite {
30	/// Return the numerator reference.
31	pub fn n(&self) -> &BigUint {
32		&self.0
33	}
34
35	/// Return the denominator reference.
36	pub fn d(&self) -> &BigUint {
37		&self.1
38	}
39
40	/// Build from a raw `n/d`.
41	pub fn from(n: BigUint, d: BigUint) -> Self {
42		Self(n, d.max(BigUint::one()))
43	}
44
45	/// Zero.
46	pub fn zero() -> Self {
47		Self(BigUint::zero(), BigUint::one())
48	}
49
50	/// One.
51	pub fn one() -> Self {
52		Self(BigUint::one(), BigUint::one())
53	}
54}
55
56impl PartialOrd for RationalInfinite {
57	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
58		Some(self.cmp(other))
59	}
60}
61
62impl Ord for RationalInfinite {
63	fn cmp(&self, other: &Self) -> Ordering {
64		// handle some edge cases.
65		if self.d() == other.d() {
66			self.n().cmp(other.n())
67		} else if self.d().is_zero() {
68			Ordering::Greater
69		} else if other.d().is_zero() {
70			Ordering::Less
71		} else {
72			// (a/b) cmp (c/d) => (a*d) cmp (c*b)
73			self.n().clone().mul(other.d()).cmp(&other.n().clone().mul(self.d()))
74		}
75	}
76}
77
78impl PartialEq for RationalInfinite {
79	fn eq(&self, other: &Self) -> bool {
80		self.cmp(other) == Ordering::Equal
81	}
82}
83
84impl From<Rational128> for RationalInfinite {
85	fn from(t: Rational128) -> Self {
86		Self(t.0.into(), t.1.into())
87	}
88}
89
90/// A wrapper for any rational number with a 128 bit numerator and denominator.
91#[derive(Clone, Copy, Default, Eq)]
92pub struct Rational128(u128, u128);
93
#[cfg(feature = "std")]
impl core::fmt::Debug for Rational128 {
	/// Print both parts plus a lossy floating-point approximation of the value.
	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
		let Self(numerator, denominator) = *self;
		// Only for human consumption; a zero denominator yields `inf` or `NaN`.
		let approx = numerator as f64 / denominator as f64;
		write!(f, "Rational128({} / {} ≈ {:.8})", numerator, denominator, approx)
	}
}
100
101#[cfg(not(feature = "std"))]
102impl core::fmt::Debug for Rational128 {
103	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
104		write!(f, "Rational128({} / {})", self.0, self.1)
105	}
106}
107
108impl Rational128 {
109	/// Zero.
110	pub fn zero() -> Self {
111		Self(0, 1)
112	}
113
114	/// One
115	pub fn one() -> Self {
116		Self(1, 1)
117	}
118
119	/// If it is zero or not
120	pub fn is_zero(&self) -> bool {
121		self.0.is_zero()
122	}
123
124	/// Build from a raw `n/d`.
125	pub fn from(n: u128, d: u128) -> Self {
126		Self(n, d.max(1))
127	}
128
129	/// Build from a raw `n/d`. This could lead to / 0 if not properly handled.
130	pub fn from_unchecked(n: u128, d: u128) -> Self {
131		Self(n, d)
132	}
133
134	/// Return the numerator.
135	pub fn n(&self) -> u128 {
136		self.0
137	}
138
139	/// Return the denominator.
140	pub fn d(&self) -> u128 {
141		self.1
142	}
143
144	/// Convert `self` to a similar rational number where denominator is the given `den`.
145	//
146	/// This only returns if the result is accurate. `None` is returned if the result cannot be
147	/// accurately calculated.
148	pub fn to_den(self, den: u128) -> Option<Self> {
149		if den == self.1 {
150			Some(self)
151		} else {
152			helpers_128bit::multiply_by_rational_with_rounding(
153				self.0,
154				den,
155				self.1,
156				Rounding::NearestPrefDown,
157			)
158			.map(|n| Self(n, den))
159		}
160	}
161
162	/// Get the least common divisor of `self` and `other`.
163	///
164	/// This only returns if the result is accurate. `None` is returned if the result cannot be
165	/// accurately calculated.
166	pub fn lcm(&self, other: &Self) -> Option<u128> {
167		// this should be tested better: two large numbers that are almost the same.
168		if self.1 == other.1 {
169			return Some(self.1)
170		}
171		let g = helpers_128bit::gcd(self.1, other.1);
172		helpers_128bit::multiply_by_rational_with_rounding(
173			self.1,
174			other.1,
175			g,
176			Rounding::NearestPrefDown,
177		)
178	}
179
180	/// A saturating add that assumes `self` and `other` have the same denominator.
181	pub fn lazy_saturating_add(self, other: Self) -> Self {
182		if other.is_zero() {
183			self
184		} else {
185			Self(self.0.saturating_add(other.0), self.1)
186		}
187	}
188
189	/// A saturating subtraction that assumes `self` and `other` have the same denominator.
190	pub fn lazy_saturating_sub(self, other: Self) -> Self {
191		if other.is_zero() {
192			self
193		} else {
194			Self(self.0.saturating_sub(other.0), self.1)
195		}
196	}
197
198	/// Addition. Simply tries to unify the denominators and add the numerators.
199	///
200	/// Overflow might happen during any of the steps. Error is returned in such cases.
201	pub fn checked_add(self, other: Self) -> Result<Self, &'static str> {
202		let lcm = self.lcm(&other).ok_or(0).map_err(|_| "failed to scale to denominator")?;
203		let self_scaled =
204			self.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
205		let other_scaled =
206			other.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
207		let n = self_scaled
208			.0
209			.checked_add(other_scaled.0)
210			.ok_or("overflow while adding numerators")?;
211		Ok(Self(n, self_scaled.1))
212	}
213
214	/// Subtraction. Simply tries to unify the denominators and subtract the numerators.
215	///
216	/// Overflow might happen during any of the steps. None is returned in such cases.
217	pub fn checked_sub(self, other: Self) -> Result<Self, &'static str> {
218		let lcm = self.lcm(&other).ok_or(0).map_err(|_| "failed to scale to denominator")?;
219		let self_scaled =
220			self.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
221		let other_scaled =
222			other.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
223
224		let n = self_scaled
225			.0
226			.checked_sub(other_scaled.0)
227			.ok_or("overflow while subtracting numerators")?;
228		Ok(Self(n, self_scaled.1))
229	}
230}
231
232impl Bounded for Rational128 {
233	fn min_value() -> Self {
234		Self(0, 1)
235	}
236
237	fn max_value() -> Self {
238		Self(Bounded::max_value(), 1)
239	}
240}
241
242impl<T: Into<u128>> From<T> for Rational128 {
243	fn from(t: T) -> Self {
244		Self::from(t.into(), 1)
245	}
246}
247
248impl PartialOrd for Rational128 {
249	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
250		Some(self.cmp(other))
251	}
252}
253
254impl Ord for Rational128 {
255	fn cmp(&self, other: &Self) -> Ordering {
256		// handle some edge cases.
257		if self.1 == other.1 {
258			self.0.cmp(&other.0)
259		} else if self.1.is_zero() {
260			Ordering::Greater
261		} else if other.1.is_zero() {
262			Ordering::Less
263		} else {
264			// Don't even compute gcd.
265			let self_n = helpers_128bit::to_big_uint(self.0) * helpers_128bit::to_big_uint(other.1);
266			let other_n =
267				helpers_128bit::to_big_uint(other.0) * helpers_128bit::to_big_uint(self.1);
268			self_n.cmp(&other_n)
269		}
270	}
271}
272
273impl PartialEq for Rational128 {
274	fn eq(&self, other: &Self) -> bool {
275		// handle some edge cases.
276		if self.1 == other.1 {
277			self.0.eq(&other.0)
278		} else {
279			let self_n = helpers_128bit::to_big_uint(self.0) * helpers_128bit::to_big_uint(other.1);
280			let other_n =
281				helpers_128bit::to_big_uint(other.0) * helpers_128bit::to_big_uint(self.1);
282			self_n.eq(&other_n)
283		}
284	}
285}
286
287pub trait MultiplyRational: Sized {
288	fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self>;
289}
290
/// Implement `MultiplyRational` for `$ulow` by widening every operand to `$uhi`.
///
/// `$uhi` must be at least twice as wide as `$ulow`, so `self * n` never overflows it.
macro_rules! impl_rrm {
	($ulow:ty, $uhi:ty) => {
		impl MultiplyRational for $ulow {
			fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self> {
				// Division by zero has no result.
				if d == 0 {
					return None
				}

				let divisor = d as $uhi;
				let product = (self as $uhi) * (n as $uhi);
				let quotient = product / divisor;
				// The remainder is strictly below `d`, so narrowing it back is lossless.
				let rem = (product % divisor) as $ulow;

				let round_up = match r {
					Rounding::Down => false,
					Rounding::Up => rem > 0,
					// `d / 2 + d % 2` is ceil(d / 2); `(d + 1) / 2` would overflow when `d`
					// is the type's maximum.
					Rounding::NearestPrefUp => rem >= d / 2 + d % 2,
					Rounding::NearestPrefDown => rem > d / 2,
				};
				let rounded = if round_up { quotient.checked_add(1)? } else { quotient };

				// The rounded quotient fits in `$uhi` but may still exceed `$ulow`.
				if rounded <= (<$ulow>::MAX as $uhi) {
					Some(rounded as $ulow)
				} else {
					None
				}
			}
		}
	};
}
323
324impl_rrm!(u8, u16);
325impl_rrm!(u16, u32);
326impl_rrm!(u32, u64);
327impl_rrm!(u64, u128);
328
329impl MultiplyRational for u128 {
330	fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self> {
331		crate::helpers_128bit::multiply_by_rational_with_rounding(self, n, d, r)
332	}
333}
334
#[cfg(test)]
mod tests {
	use super::{helpers_128bit::*, *};
	use static_assertions::const_assert;

	const MAX128: u128 = u128::MAX;
	const MAX64: u128 = u64::MAX as u128;
	const MAX64_2: u128 = 2 * u64::MAX as u128;

	/// Build a `Rational128` verbatim, without clamping a zero denominator.
	fn r(p: u128, q: u128) -> Rational128 {
		Rational128(p, q)
	}

	/// Reference `a * b / c` computed with 256-bit intermediates.
	///
	/// A zero `c` is treated as `1`. If the exact result does not fit in a `u128`, `a` is
	/// returned unchanged.
	fn mul_div(a: u128, b: u128, c: u128) -> u128 {
		use primitive_types::U256;
		if a.is_zero() {
			return Zero::zero()
		}
		let c = c.max(1);

		// e for extended
		let ae: U256 = a.into();
		let be: U256 = b.into();
		let ce: U256 = c.into();

		let r = ae * be / ce;
		if r > u128::MAX.into() {
			a
		} else {
			r.as_u128()
		}
	}

	#[test]
	fn truth_value_function_works() {
		assert_eq!(mul_div(2u128.pow(100), 8, 4), 2u128.pow(101));
		assert_eq!(mul_div(2u128.pow(100), 4, 8), 2u128.pow(99));

		// and it returns a if result cannot fit
		assert_eq!(mul_div(MAX128 - 10, 2, 1), MAX128 - 10);
	}

	#[test]
	fn to_denom_works() {
		// simple up and down
		assert_eq!(r(1, 5).to_den(10), Some(r(2, 10)));
		assert_eq!(r(4, 10).to_den(5), Some(r(2, 5)));

		// up and down with large numbers
		assert_eq!(r(MAX128 - 10, MAX128).to_den(10), Some(r(10, 10)));
		assert_eq!(r(MAX128 / 2, MAX128).to_den(10), Some(r(5, 10)));

		// large to perbill. This is very well needed for npos-elections.
		assert_eq!(r(MAX128 / 2, MAX128).to_den(1_000_000_000), Some(r(500_000_000, 1_000_000_000)));

		// large to large
		assert_eq!(r(MAX128 / 2, MAX128).to_den(MAX128 / 2), Some(r(MAX128 / 4, MAX128 / 2)));
	}

	#[test]
	fn gcd_works() {
		assert_eq!(gcd(10, 5), 5);
		assert_eq!(gcd(7, 22), 1);
	}

	#[test]
	fn lcm_works() {
		// simple stuff
		assert_eq!(r(3, 10).lcm(&r(4, 15)).unwrap(), 30);
		assert_eq!(r(5, 30).lcm(&r(1, 7)).unwrap(), 210);
		assert_eq!(r(5, 30).lcm(&r(1, 10)).unwrap(), 30);

		// large numbers
		assert_eq!(r(1_000_000_000, MAX128).lcm(&r(7_000_000_000, MAX128 - 1)), None);
		assert_eq!(
			r(1_000_000_000, MAX64).lcm(&r(7_000_000_000, MAX64 - 1)),
			Some(340282366920938463408034375210639556610),
		);
		const_assert!(340282366920938463408034375210639556610 < MAX128);
		const_assert!(340282366920938463408034375210639556610 == MAX64 * (MAX64 - 1));
	}

	#[test]
	fn add_works() {
		// works
		assert_eq!(r(3, 10).checked_add(r(1, 10)).unwrap(), r(2, 5));
		assert_eq!(r(3, 10).checked_add(r(3, 7)).unwrap(), r(51, 70));

		// errors
		assert_eq!(
			r(1, MAX128).checked_add(r(1, MAX128 - 1)),
			Err("failed to scale to denominator"),
		);
		assert_eq!(
			r(7, MAX128).checked_add(r(MAX128, MAX128)),
			Err("overflow while adding numerators"),
		);
		assert_eq!(
			r(MAX128, MAX128).checked_add(r(MAX128, MAX128)),
			Err("overflow while adding numerators"),
		);
	}

	#[test]
	fn sub_works() {
		// works
		assert_eq!(r(3, 10).checked_sub(r(1, 10)).unwrap(), r(1, 5));
		assert_eq!(r(6, 10).checked_sub(r(3, 7)).unwrap(), r(12, 70));

		// errors
		assert_eq!(
			r(2, MAX128).checked_sub(r(1, MAX128 - 1)),
			Err("failed to scale to denominator"),
		);
		assert_eq!(
			r(7, MAX128).checked_sub(r(MAX128, MAX128)),
			Err("overflow while subtracting numerators"),
		);
		assert_eq!(r(1, 10).checked_sub(r(2, 10)), Err("overflow while subtracting numerators"));
	}

	#[test]
	fn ordering_and_eq_works() {
		assert!(r(1, 2) > r(1, 3));
		assert!(r(1, 2) > r(2, 6));

		assert!(r(1, 2) < r(6, 6));
		assert!(r(2, 1) > r(2, 6));

		assert!(r(5, 10) == r(1, 2));
		assert!(r(1, 2) == r(1, 2));

		assert!(r(1, 1490000000000200000) > r(1, 1490000000000200001));

		// a zero denominator compares greater than any non-zero denominator.
		assert!(r(1, 0) > r(MAX128, 1));
		assert!(r(0, 5) < r(0, 0));
	}

	#[test]
	fn multiply_by_rational_with_rounding_works() {
		assert_eq!(multiply_by_rational_with_rounding(7, 2, 3, Rounding::Down).unwrap(), 7 * 2 / 3);
		assert_eq!(
			multiply_by_rational_with_rounding(7, 20, 30, Rounding::Down).unwrap(),
			7 * 2 / 3
		);
		assert_eq!(
			multiply_by_rational_with_rounding(20, 7, 30, Rounding::Down).unwrap(),
			7 * 2 / 3
		);

		assert_eq!(
			// MAX128 % 3 == 0
			multiply_by_rational_with_rounding(MAX128, 2, 3, Rounding::Down).unwrap(),
			MAX128 / 3 * 2,
		);
		assert_eq!(
			// MAX128 % 7 == 3
			multiply_by_rational_with_rounding(MAX128, 5, 7, Rounding::Down).unwrap(),
			(MAX128 / 7 * 5) + (3 * 5 / 7),
		);
		assert_eq!(
			// MAX128 % 13 == 8
			multiply_by_rational_with_rounding(MAX128, 11, 13, Rounding::Down).unwrap(),
			(MAX128 / 13 * 11) + (8 * 11 / 13),
		);
		assert_eq!(
			// MAX128 % 1000 == 455
			multiply_by_rational_with_rounding(MAX128, 555, 1000, Rounding::Down).unwrap(),
			(MAX128 / 1000 * 555) + (455 * 555 / 1000),
		);

		assert_eq!(
			multiply_by_rational_with_rounding(2 * MAX64 - 1, MAX64, MAX64, Rounding::Down)
				.unwrap(),
			2 * MAX64 - 1
		);
		assert_eq!(
			multiply_by_rational_with_rounding(2 * MAX64 - 1, MAX64 - 1, MAX64, Rounding::Down)
				.unwrap(),
			2 * MAX64 - 3
		);

		assert_eq!(
			multiply_by_rational_with_rounding(MAX64 + 100, MAX64_2, MAX64_2 / 2, Rounding::Down)
				.unwrap(),
			(MAX64 + 100) * 2,
		);
		assert_eq!(
			multiply_by_rational_with_rounding(
				MAX64 + 100,
				MAX64_2 / 100,
				MAX64_2 / 200,
				Rounding::Down
			)
			.unwrap(),
			(MAX64 + 100) * 2,
		);

		assert_eq!(
			multiply_by_rational_with_rounding(
				2u128.pow(66) - 1,
				2u128.pow(65) - 1,
				2u128.pow(65),
				Rounding::Down
			)
			.unwrap(),
			73786976294838206461,
		);
		assert_eq!(
			multiply_by_rational_with_rounding(1_000_000_000, MAX128 / 8, MAX128 / 2, Rounding::Up)
				.unwrap(),
			250000000
		);

		assert_eq!(
			multiply_by_rational_with_rounding(
				29459999999999999988000u128,
				1000000000000000000u128,
				10000000000000000000u128,
				Rounding::Down
			)
			.unwrap(),
			2945999999999999998800u128
		);
	}

	#[test]
	fn multiply_by_rational_with_rounding_a_b_are_interchangeable() {
		assert_eq!(
			multiply_by_rational_with_rounding(10, MAX128, MAX128 / 2, Rounding::NearestPrefDown),
			Some(20)
		);
		assert_eq!(
			multiply_by_rational_with_rounding(MAX128, 10, MAX128 / 2, Rounding::NearestPrefDown),
			Some(20)
		);
	}

	#[test]
	#[ignore]
	fn multiply_by_rational_with_rounding_fuzzed_equation() {
		assert_eq!(
			multiply_by_rational_with_rounding(
				154742576605164960401588224,
				9223376310179529214,
				549756068598,
				Rounding::NearestPrefDown
			),
			Some(2596149632101417846585204209223679)
		);
	}
}