sp_runtime/
curve.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Provides some utilities to define a piecewise linear function.
19
20use crate::{
21	traits::{AtLeast32BitUnsigned, SaturatedConversion},
22	Perbill,
23};
24use core::ops::Sub;
25use scale_info::TypeInfo;
26
27/// Piecewise Linear function in [0, 1] -> [0, 1].
28#[derive(PartialEq, Eq, sp_core::RuntimeDebug, TypeInfo)]
29pub struct PiecewiseLinear<'a> {
30	/// Array of points. Must be in order from the lowest abscissas to the highest.
31	pub points: &'a [(Perbill, Perbill)],
32	/// The maximum value that can be returned.
33	pub maximum: Perbill,
34}
35
/// Absolute difference `|a - b|` of two ordered values.
///
/// The smaller operand is always subtracted from the larger one, so this never underflows for
/// unsigned types. Only one comparison is performed and neither operand is cloned.
fn abs_sub<N: Ord + Sub<Output = N>>(a: N, b: N) -> N {
	if a >= b {
		a - b
	} else {
		b - a
	}
}
39
40impl<'a> PiecewiseLinear<'a> {
41	/// Compute `f(n/d)*d` with `n <= d`. This is useful to avoid loss of precision.
42	pub fn calculate_for_fraction_times_denominator<N>(&self, n: N, d: N) -> N
43	where
44		N: AtLeast32BitUnsigned + Clone,
45	{
46		let n = n.min(d.clone());
47
48		if self.points.is_empty() {
49			return N::zero()
50		}
51
52		let next_point_index = self.points.iter().position(|p| n < p.0 * d.clone());
53
54		let (prev, next) = if let Some(next_point_index) = next_point_index {
55			if let Some(previous_point_index) = next_point_index.checked_sub(1) {
56				(self.points[previous_point_index], self.points[next_point_index])
57			} else {
58				// There is no previous points, take first point ordinate
59				return self.points.first().map(|p| p.1).unwrap_or_else(Perbill::zero) * d
60			}
61		} else {
62			// There is no next points, take last point ordinate
63			return self.points.last().map(|p| p.1).unwrap_or_else(Perbill::zero) * d
64		};
65
66		let delta_y = multiply_by_rational_saturating(
67			abs_sub(n.clone(), prev.0 * d.clone()),
68			abs_sub(next.1.deconstruct(), prev.1.deconstruct()),
69			// Must not saturate as prev abscissa > next abscissa
70			next.0.deconstruct().saturating_sub(prev.0.deconstruct()),
71		);
72
73		// If both subtractions are same sign then result is positive
74		if (n > prev.0 * d.clone()) == (next.1.deconstruct() > prev.1.deconstruct()) {
75			(prev.1 * d).saturating_add(delta_y)
76		// Otherwise result is negative
77		} else {
78			(prev.1 * d).saturating_sub(delta_y)
79		}
80	}
81}
82
83// Compute value * p / q.
84// This is guaranteed not to overflow on whatever values nor lose precision.
85// `q` must be superior to zero.
86fn multiply_by_rational_saturating<N>(value: N, p: u32, q: u32) -> N
87where
88	N: AtLeast32BitUnsigned + Clone,
89{
90	let q = q.max(1);
91
92	// Mul can saturate if p > q
93	let result_divisor_part = (value.clone() / q.into()).saturating_mul(p.into());
94
95	let result_remainder_part = {
96		let rem = value % q.into();
97
98		// Fits into u32 because q is u32 and remainder < q
99		let rem_u32 = rem.saturated_into::<u32>();
100
101		// Multiplication fits into u64 as both term are u32
102		let rem_part = rem_u32 as u64 * p as u64 / q as u64;
103
104		// Can saturate if p > q
105		rem_part.saturated_into::<N>()
106	};
107
108	// Can saturate if p > q
109	result_divisor_part.saturating_add(result_remainder_part)
110}
111
#[test]
fn test_multiply_by_rational_saturating() {
	// Each argument is sampled at `STEPS + 1` evenly spaced points across its whole range, and
	// the result is compared to an exact `u128` computation saturated to `u64::MAX`.
	const STEPS: u32 = 100;

	fn spread_u64(step: u32) -> u64 {
		(u128::from(step) * u128::from(u64::MAX) / u128::from(STEPS)).try_into().unwrap()
	}

	fn spread_u32(step: u32) -> u32 {
		(u64::from(step) * u64::from(u32::MAX) / u64::from(STEPS)).try_into().unwrap()
	}

	for value in (0..=STEPS).map(spread_u64) {
		for p in (0..=STEPS).map(spread_u32) {
			for q in (1..=STEPS).map(spread_u32) {
				let exact = u128::from(value) * u128::from(p) / u128::from(q);
				let expected: u64 = exact.try_into().unwrap_or(u64::MAX);
				assert_eq!(multiply_by_rational_saturating(value, p, q), expected);
			}
		}
	}
}
131
#[test]
fn test_calculate_for_fraction_times_denominator() {
	// Tent-shaped curve: 0.5 at x = 0, rising to 1 at x = 0.5, falling back to 0 at x = 1.
	let curve = PiecewiseLinear {
		points: &[
			(Perbill::from_parts(0_000_000_000), Perbill::from_parts(0_500_000_000)),
			(Perbill::from_parts(0_500_000_000), Perbill::from_parts(1_000_000_000)),
			(Perbill::from_parts(1_000_000_000), Perbill::from_parts(0_000_000_000)),
		],
		maximum: Perbill::from_parts(1_000_000_000),
	};

	// Closed form of `curve(n / d) * d`.
	fn reference(n: u64, d: u64) -> u64 {
		let midpoint = Perbill::from_parts(0_500_000_000) * d;
		if n > midpoint {
			(d as u128 * 2 - n as u128 * 2).try_into().unwrap()
		} else {
			n + d / 2
		}
	}

	const STEPS: u32 = 100;
	let scale = |step: u32| -> u64 {
		(u128::from(step) * u128::from(u64::MAX) / u128::from(STEPS)).try_into().unwrap()
	};

	for d_step in 0..=STEPS {
		let d = scale(d_step);
		for n_step in 0..=d_step {
			let n = scale(n_step);

			let res = curve.calculate_for_fraction_times_denominator(n, d);
			let expected = reference(n, d);

			assert!(abs_sub(res, expected) <= 1);
		}
	}
}