sp_arithmetic/
lib.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Minimal fixed point arithmetic primitives and types for runtime.
19
20#![cfg_attr(not(feature = "std"), no_std)]
21
22extern crate alloc;
23
/// Assert that `$x` lies within `$error` of `$y`, i.e. `$y - $error <= $x <= $y + $error`
/// (both bounds inclusive).
///
/// Copied from `sp-runtime` and documented there.
///
/// Notes:
/// - Argument expressions may be evaluated more than once (they appear several times in the
///   expansion), so avoid passing expressions with side effects.
/// - The bounds are computed with plain `-` and `+`; for unsigned integer types `$y - $error`
///   overflows (panicking in debug builds) when `$error > $y`.
#[macro_export]
macro_rules! assert_eq_error_rate {
	($x:expr, $y:expr, $error:expr $(,)?) => {
		assert!(
			($x) >= (($y) - ($error)) && ($x) <= (($y) + ($error)),
			"{:?} != {:?} (with error rate {:?})",
			$x,
			$y,
			$error,
		);
	};
}
37
38pub mod biguint;
39pub mod fixed_point;
40pub mod helpers_128bit;
41pub mod per_things;
42pub mod rational;
43pub mod traits;
44
45pub use fixed_point::{
46	FixedI128, FixedI64, FixedPointNumber, FixedPointOperand, FixedU128, FixedU64,
47};
48pub use per_things::{
49	InnerOf, MultiplyArg, PerThing, PerU16, Perbill, Percent, Permill, Perquintill, RationalArg,
50	ReciprocalArg, Rounding, SignedRounding, UpperOf,
51};
52pub use rational::{MultiplyRational, Rational128, RationalInfinite};
53
54use alloc::vec::Vec;
55use core::{cmp::Ordering, fmt::Debug};
56use traits::{BaseArithmetic, One, SaturatedConversion, Unsigned, Zero};
57
58use codec::{Decode, DecodeWithMemTracking, Encode, MaxEncodedLen};
59use scale_info::TypeInfo;
60
61#[cfg(feature = "serde")]
62use serde::{Deserialize, Serialize};
63
64/// Arithmetic errors.
65#[derive(
66	Eq,
67	PartialEq,
68	Clone,
69	Copy,
70	Encode,
71	Decode,
72	DecodeWithMemTracking,
73	Debug,
74	TypeInfo,
75	MaxEncodedLen,
76)]
77#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
78pub enum ArithmeticError {
79	/// Underflow.
80	Underflow,
81	/// Overflow.
82	Overflow,
83	/// Division by zero.
84	DivisionByZero,
85}
86
87impl From<ArithmeticError> for &'static str {
88	fn from(e: ArithmeticError) -> &'static str {
89		match e {
90			ArithmeticError::Underflow => "An underflow would occur",
91			ArithmeticError::Overflow => "An overflow would occur",
92			ArithmeticError::DivisionByZero => "Division by zero",
93		}
94	}
95}
96
/// Trait for comparing two numbers with a threshold (tolerance).
///
/// Returns:
/// - `Ordering::Greater` if `self` is greater than `other + threshold`.
/// - `Ordering::Less` if `self` is less than `other - threshold`.
/// - `Ordering::Equal` otherwise, i.e. when `self` lies inside the inclusive band
///   `[other - threshold, other + threshold]`.
///
/// The blanket implementation in this crate computes the band bounds with saturating arithmetic
/// and falls back to the plain ordering when `threshold` is zero.
pub trait ThresholdOrd<T> {
	/// Compare `self` against `other`, treating values within `threshold` of `other` as equal.
	fn tcmp(&self, other: &T, threshold: T) -> Ordering;
}
107
108impl<T> ThresholdOrd<T> for T
109where
110	T: Ord + PartialOrd + Copy + Clone + traits::Zero + traits::Saturating,
111{
112	fn tcmp(&self, other: &T, threshold: T) -> Ordering {
113		// early exit.
114		if threshold.is_zero() {
115			return self.cmp(other);
116		}
117
118		let upper_bound = other.saturating_add(threshold);
119		let lower_bound = other.saturating_sub(threshold);
120
121		if upper_bound <= lower_bound {
122			// defensive only. Can never happen.
123			self.cmp(other)
124		} else {
125			// upper_bound is guaranteed now to be bigger than lower.
126			match (self.cmp(&lower_bound), self.cmp(&upper_bound)) {
127				(Ordering::Greater, Ordering::Greater) => Ordering::Greater,
128				(Ordering::Less, Ordering::Less) => Ordering::Less,
129				_ => Ordering::Equal,
130			}
131		}
132	}
133}
134
/// A collection of `T` values whose individual elements can be adjusted so that, together, they
/// add up to a requested total.
///
/// The order of the items in the collection may influence which elements get adjusted, and
/// therefore the exact result.
pub trait Normalizable<T> {
	/// Returns a copy of `self` adjusted so that its elements sum up to `targeted_sum`.
	///
	/// # Errors
	///
	/// Returns `Err` with a static description whenever the implementation cannot guarantee
	/// that the resulting elements sum exactly to `targeted_sum`.
	fn normalize(&self, targeted_sum: T) -> Result<Vec<T>, &'static str>;
}
146
147macro_rules! impl_normalize_for_numeric {
148	($($numeric:ty),*) => {
149		$(
150			impl Normalizable<$numeric> for Vec<$numeric> {
151				fn normalize(&self, targeted_sum: $numeric) -> Result<Vec<$numeric>, &'static str> {
152					normalize(self.as_ref(), targeted_sum)
153				}
154			}
155		)*
156	};
157}
158
159impl_normalize_for_numeric!(u8, u16, u32, u64, u128);
160
161impl<P: PerThing> Normalizable<P> for Vec<P> {
162	fn normalize(&self, targeted_sum: P) -> Result<Vec<P>, &'static str> {
163		let uppers = self.iter().map(|p| <UpperOf<P>>::from(p.deconstruct())).collect::<Vec<_>>();
164
165		let normalized =
166			normalize(uppers.as_ref(), <UpperOf<P>>::from(targeted_sum.deconstruct()))?;
167
168		Ok(normalized
169			.into_iter()
170			.map(|i: UpperOf<P>| P::from_parts(i.saturated_into::<P::Inner>()))
171			.collect())
172	}
173}
174
175/// Normalize `input` so that the sum of all elements reaches `targeted_sum`.
176///
177/// This implementation is currently in a balanced position between being performant and accurate.
178///
179/// 1. We prefer storing original indices, and sorting the `input` only once. This will save the
180///    cost of sorting per round at the cost of a little bit of memory.
181/// 2. The granularity of increment/decrements is determined by the number of elements in `input`
182///    and their sum difference with `targeted_sum`, namely `diff = diff(sum(input), target_sum)`.
183///    This value is then distributed into `per_round = diff / input.len()` and `leftover = diff %
184///    round`. First, per_round is applied to all elements of input, and then we move to leftover,
185///    in which case we add/subtract 1 by 1 until `leftover` is depleted.
186///
187/// When the sum is less than the target, the above approach always holds. In this case, then each
188/// individual element is also less than target. Thus, by adding `per_round` to each item, neither
189/// of them can overflow the numeric bound of `T`. In fact, neither of the can go beyond
190/// `target_sum`*.
191///
192/// If sum is more than target, there is small twist. The subtraction of `per_round`
193/// form each element might go below zero. In this case, we saturate and add the error to the
194/// `leftover` value. This ensures that the result will always stay accurate, yet it might cause the
195/// execution to become increasingly slow, since leftovers are applied one by one.
196///
197/// All in all, the complicated case above is rare to happen in most use cases within this repo ,
198/// hence we opt for it due to its simplicity.
199///
200/// This function will return an error is if length of `input` cannot fit in `T`, or if `sum(input)`
201/// cannot fit inside `T`.
202///
203/// * This proof is used in the implementation as well.
204pub fn normalize<T>(input: &[T], targeted_sum: T) -> Result<Vec<T>, &'static str>
205where
206	T: Clone + Copy + Ord + BaseArithmetic + Unsigned + Debug,
207{
208	// compute sum and return error if failed.
209	let mut sum = T::zero();
210	for t in input.iter() {
211		sum = sum.checked_add(t).ok_or("sum of input cannot fit in `T`")?;
212	}
213
214	// convert count and return error if failed.
215	let count = input.len();
216	let count_t: T = count.try_into().map_err(|_| "length of `inputs` cannot fit in `T`")?;
217
218	// Nothing to do here.
219	if count.is_zero() {
220		return Ok(Vec::<T>::new());
221	}
222
223	let diff = targeted_sum.max(sum) - targeted_sum.min(sum);
224	if diff.is_zero() {
225		return Ok(input.to_vec());
226	}
227
228	let needs_bump = targeted_sum > sum;
229	let per_round = diff / count_t;
230	let mut leftover = diff % count_t;
231
232	// sort output once based on diff. This will require more data transfer and saving original
233	// index, but we sort only twice instead: once now and once at the very end.
234	let mut output_with_idx = input.iter().cloned().enumerate().collect::<Vec<(usize, T)>>();
235	output_with_idx.sort_by_key(|x| x.1);
236
237	if needs_bump {
238		// must increase the values a bit. Bump from the min element. Index of minimum is now zero
239		// because we did a sort. If at any point the min goes greater or equal the `max_threshold`,
240		// we move to the next minimum.
241		let mut min_index = 0;
242		// at this threshold we move to next index.
243		let threshold = targeted_sum / count_t;
244
245		if !per_round.is_zero() {
246			for _ in 0..count {
247				output_with_idx[min_index].1 = output_with_idx[min_index]
248					.1
249					.checked_add(&per_round)
250					.expect("Proof provided in the module doc; qed.");
251				if output_with_idx[min_index].1 >= threshold {
252					min_index += 1;
253					min_index %= count;
254				}
255			}
256		}
257
258		// continue with the previous min_index
259		while !leftover.is_zero() {
260			output_with_idx[min_index].1 = output_with_idx[min_index]
261				.1
262				.checked_add(&T::one())
263				.expect("Proof provided in the module doc; qed.");
264			if output_with_idx[min_index].1 >= threshold {
265				min_index += 1;
266				min_index %= count;
267			}
268			leftover -= One::one();
269		}
270	} else {
271		// must decrease the stakes a bit. decrement from the max element. index of maximum is now
272		// last. if at any point the max goes less or equal the `min_threshold`, we move to the next
273		// maximum.
274		let mut max_index = count - 1;
275		// at this threshold we move to next index.
276		let threshold = output_with_idx
277			.first()
278			.expect("length of input is greater than zero; it must have a first; qed")
279			.1;
280
281		if !per_round.is_zero() {
282			for _ in 0..count {
283				output_with_idx[max_index].1 =
284					output_with_idx[max_index].1.checked_sub(&per_round).unwrap_or_else(|| {
285						let remainder = per_round - output_with_idx[max_index].1;
286						leftover += remainder;
287						output_with_idx[max_index].1.saturating_sub(per_round)
288					});
289				if output_with_idx[max_index].1 <= threshold {
290					max_index = max_index.checked_sub(1).unwrap_or(count - 1);
291				}
292			}
293		}
294
295		// continue with the previous max_index
296		while !leftover.is_zero() {
297			if let Some(next) = output_with_idx[max_index].1.checked_sub(&One::one()) {
298				output_with_idx[max_index].1 = next;
299				if output_with_idx[max_index].1 <= threshold {
300					max_index = max_index.checked_sub(1).unwrap_or(count - 1);
301				}
302				leftover -= One::one();
303			} else {
304				max_index = max_index.checked_sub(1).unwrap_or(count - 1);
305			}
306		}
307	}
308
309	debug_assert_eq!(
310		output_with_idx.iter().fold(T::zero(), |acc, (_, x)| acc + *x),
311		targeted_sum,
312		"sum({:?}) != {:?}",
313		output_with_idx,
314		targeted_sum
315	);
316
317	// sort again based on the original index.
318	output_with_idx.sort_by_key(|x| x.0);
319	Ok(output_with_idx.into_iter().map(|(_, t)| t).collect())
320}
321
#[cfg(test)]
mod normalize_tests {
	use super::*;

	/// Asserts that normalizing the `u32` slice `input` towards `target` yields exactly `expected`.
	fn assert_normalized(input: &[u32], target: u32, expected: &[u32]) {
		assert_eq!(normalize(input, target).unwrap(), expected);
	}

	#[test]
	fn work_for_all_types() {
		// The same input must normalize identically for every unsigned integer type, as long as
		// the length of the input can be converted into that type.
		macro_rules! check_type {
			($t:ty) => {{
				let input: Vec<$t> = vec![8, 9, 7, 10];
				assert_eq!(normalize(input.as_slice(), 40).unwrap(), vec![10, 10, 10, 10]);
			}};
		}
		check_type!(u128);
		check_type!(u64);
		check_type!(u32);
		check_type!(u16);
		check_type!(u8);
	}

	#[test]
	fn fails_on_if_input_sum_large() {
		// 255 ones add up to exactly `u8::MAX`, which still fits.
		let at_limit = vec![1u8; 255];
		assert!(normalize(at_limit.as_slice(), 10).is_ok());

		// One more element makes the sum overflow `u8`; the sum is checked before the length.
		let over_limit = vec![1u8; 256];
		assert_eq!(normalize(over_limit.as_slice(), 10), Err("sum of input cannot fit in `T`"));
	}

	#[test]
	fn does_not_fail_on_subtraction_overflow() {
		// Subtracting a full `per_round` from small elements would underflow; the shortfall must
		// be carried over instead of panicking.
		for (input, target, expected) in
			[(vec![1u8, 100, 100], 10u8, vec![1u8, 9, 0]), (vec![1u8, 8, 9], 1, vec![0, 1, 0])]
		{
			assert_eq!(normalize(input.as_slice(), target).unwrap(), expected);
		}
	}

	#[test]
	fn works_for_vec() {
		let stakes = vec![8u32, 9, 7, 10];
		assert_eq!(stakes.normalize(40).unwrap(), vec![10u32, 10, 10, 10]);
	}

	#[test]
	fn works_for_per_thing() {
		// Three times 33% is topped up to exactly 100%.
		let thirds = vec![Perbill::from_percent(33); 3];
		assert_eq!(
			thirds.normalize(Perbill::one()).unwrap(),
			vec![
				Perbill::from_parts(333333334),
				Perbill::from_parts(333333333),
				Perbill::from_parts(333333333)
			]
		);

		let uneven =
			vec![Perbill::from_percent(20), Perbill::from_percent(15), Perbill::from_percent(30)];
		assert_eq!(
			uneven.normalize(Perbill::one()).unwrap(),
			vec![
				Perbill::from_parts(316666668),
				Perbill::from_parts(383333332),
				Perbill::from_parts(300000000)
			]
		);
	}

	#[test]
	fn can_work_for_peru16() {
		// `PerU16` is a special case: its inner type is fully used by its accuracy, so the sum of
		// several parts may not fit in the inner type. Normalizing through the upper type of the
		// per-thing must still succeed.
		let parts = vec![PerU16::from_percent(40); 3];
		assert_eq!(
			parts.normalize(PerU16::one()).unwrap(),
			vec![PerU16::from_parts(21845); 3] // ~33% each
		);
	}

	#[test]
	fn normalize_works_all_le() {
		assert_normalized(&[8, 9, 7, 10], 40, &[10, 10, 10, 10]);
		assert_normalized(&[7, 7, 7, 7], 40, &[10, 10, 10, 10]);
		assert_normalized(&[7, 7, 7, 10], 40, &[11, 11, 8, 10]);
		assert_normalized(&[7, 8, 7, 10], 40, &[11, 8, 11, 10]);
		assert_normalized(&[7, 7, 8, 10], 40, &[11, 11, 8, 10]);
	}

	#[test]
	fn normalize_works_some_ge() {
		assert_normalized(&[8, 11, 9, 10], 40, &[10, 11, 9, 10]);
	}

	#[test]
	fn always_inc_min() {
		assert_normalized(&[10, 7, 10, 10], 40, &[10, 10, 10, 10]);
		assert_normalized(&[10, 10, 7, 10], 40, &[10, 10, 10, 10]);
		assert_normalized(&[10, 10, 10, 7], 40, &[10, 10, 10, 10]);
	}

	#[test]
	fn normalize_works_all_ge() {
		assert_normalized(&[12, 11, 13, 10], 40, &[10, 10, 10, 10]);
		assert_normalized(&[13, 13, 13, 13], 40, &[10, 10, 10, 10]);
		assert_normalized(&[13, 13, 13, 10], 40, &[12, 9, 9, 10]);
		assert_normalized(&[13, 12, 13, 10], 40, &[9, 12, 9, 10]);
		assert_normalized(&[13, 13, 12, 10], 40, &[9, 9, 12, 10]);
	}
}
441
// Small usage examples for per-things and fixed-point types. The `#[docify::export]` attribute
// exports these test bodies (comments included) so they can be embedded in documentation.
#[cfg(test)]
mod per_and_fixed_examples {
	use super::*;

	#[docify::export]
	#[test]
	fn percent_mult() {
		let percent = Percent::from_rational(5u32, 100u32); // aka, 5%
		let five_percent_of_100 = percent * 100u32; // 5% of 100 is 5.
		assert_eq!(five_percent_of_100, 5)
	}
	#[docify::export]
	#[test]
	fn perbill_example() {
		let p = Perbill::from_percent(80);
		// 800000000 parts per billion, representing 0.800000000.
		// Precision is in the billionths place.
		assert_eq!(p.deconstruct(), 800000000);
	}

	#[docify::export]
	#[test]
	fn percent_example() {
		// 190 / 400 is 47.5%, which is stored as 47 parts per hundred.
		let percent = Percent::from_rational(190u32, 400u32);
		assert_eq!(percent.deconstruct(), 47);
	}

	#[docify::export]
	#[test]
	fn fixed_u64_block_computation_example() {
		// Calculate a very rudimentary on-chain price from supply / demand
		// Supply: Cores available per block
		// Demand: Cores being ordered per block
		let price = FixedU64::from_rational(5u128, 10u128);

		// 0.5 DOT per core
		assert_eq!(price, FixedU64::from_float(0.5));

		// Now, the story has changed - lots of demand means we buy as many cores as there are
		// available. This also means that price goes up! For the sake of simplicity, we don't care
		// about who gets a core - just about our very simple price model

		// Calculate a very rudimentary on-chain price from supply / demand
		// Supply: Cores available per block
		// Demand: Cores being ordered per block
		let price = FixedU64::from_rational(19u128, 10u128);

		// 1.9 DOT per core
		assert_eq!(price, FixedU64::from_float(1.9));
	}

	#[docify::export]
	#[test]
	fn fixed_u64() {
		// The difference between this and per-things is that per-things operate within the realm
		// of [0, 1]. In cases where we need > 1, we can use fixed types such as FixedU64.

		let rational_1 = FixedU64::from_rational(10, 5); // "200%" aka 2.
		let rational_2 = FixedU64::from_rational_with_rounding(5, 10, Rounding::Down); // "50%" aka 0.50...

		assert_eq!(rational_1, (2u64).into());
		assert_eq!(rational_2.into_perbill(), Perbill::from_float(0.5));
	}

	#[docify::export]
	#[test]
	fn fixed_u64_operation_example() {
		let rational_1 = FixedU64::from_rational(10, 5); // "200%" aka 2.
		let rational_2 = FixedU64::from_rational(8, 5); // "160%" aka 1.6.

		let addition = rational_1 + rational_2;
		let multiplication = rational_1 * rational_2;
		let division = rational_1 / rational_2;
		let subtraction = rational_1 - rational_2;

		assert_eq!(addition, FixedU64::from_float(3.6));
		assert_eq!(multiplication, FixedU64::from_float(3.2));
		assert_eq!(division, FixedU64::from_float(1.25));
		assert_eq!(subtraction, FixedU64::from_float(0.4));
	}
}
523
// Tests for `ThresholdOrd::tcmp`, plus a few per-thing and `Saturating` regression checks.
#[cfg(test)]
mod threshold_compare_tests {
	use super::*;
	use crate::traits::Saturating;
	use core::cmp::Ordering;

	#[test]
	fn epsilon_ord_works() {
		let b = 115u32;
		// 10% of 115 is 11.5, which `mul_ceil` rounds up to a threshold of 12.
		let e = Perbill::from_percent(10).mul_ceil(b);

		// Everything in the inclusive band [115 - 12 (103), 115 + 12 (127)] compares as equal.
		assert_eq!((103u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((104u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((115u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((120u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((126u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((127u32).tcmp(&b, e), Ordering::Equal);

		// Just outside the band on either side.
		assert_eq!((128u32).tcmp(&b, e), Ordering::Greater);
		assert_eq!((102u32).tcmp(&b, e), Ordering::Less);
	}

	#[test]
	fn epsilon_ord_works_with_small_epc() {
		let b = 115u32;
		// Way less than 1 percent: the threshold will be zero, so the result should be the same
		// as the normal `cmp` everywhere.
		let e = Perbill::from_parts(100) * b;

		assert_eq!((103u32).tcmp(&b, e), (103u32).cmp(&b));
		assert_eq!((104u32).tcmp(&b, e), (104u32).cmp(&b));
		assert_eq!((115u32).tcmp(&b, e), (115u32).cmp(&b));
		assert_eq!((120u32).tcmp(&b, e), (120u32).cmp(&b));
		assert_eq!((126u32).tcmp(&b, e), (126u32).cmp(&b));
		assert_eq!((127u32).tcmp(&b, e), (127u32).cmp(&b));

		assert_eq!((128u32).tcmp(&b, e), (128u32).cmp(&b));
		assert_eq!((102u32).tcmp(&b, e), (102u32).cmp(&b));
	}

	#[test]
	fn peru16_rational_does_not_overflow() {
		// A historical example that used to panic, but only for per-thing types whose accuracy
		// uses the full capacity of their inner type, e.g. PerU16. It must not panic.
		let _ = PerU16::from_rational(17424870u32, 17424870);
	}

	#[test]
	fn saturating_mul_works() {
		// `saturating_mul` clamps to the bounds of the signed type in both directions.
		assert_eq!(Saturating::saturating_mul(2, i32::MIN), i32::MIN);
		assert_eq!(Saturating::saturating_mul(2, i32::MAX), i32::MAX);
	}

	#[test]
	fn saturating_pow_works() {
		// Any base to the power of zero is one.
		assert_eq!(Saturating::saturating_pow(i32::MIN, 0), 1);
		assert_eq!(Saturating::saturating_pow(i32::MAX, 0), 1);
		// An odd power of a negative base clamps to MIN; an even power clamps to MAX.
		assert_eq!(Saturating::saturating_pow(i32::MIN, 3), i32::MIN);
		assert_eq!(Saturating::saturating_pow(i32::MIN, 2), i32::MAX);
		assert_eq!(Saturating::saturating_pow(i32::MAX, 2), i32::MAX);
	}
}