1use crate::{
21 traits::{AtLeast32BitUnsigned, SaturatedConversion},
22 Perbill,
23};
24use core::ops::Sub;
25use scale_info::TypeInfo;
26
27#[derive(PartialEq, Eq, sp_core::RuntimeDebug, TypeInfo)]
29pub struct PiecewiseLinear<'a> {
30 pub points: &'a [(Perbill, Perbill)],
32 pub maximum: Perbill,
34}
35
/// Returns the absolute difference `|a - b|`.
///
/// The smaller operand is always subtracted from the larger one, so for unsigned types the
/// subtraction can never underflow. Only one comparison is made, and no operand is cloned.
fn abs_sub<N: Ord + Sub<Output = N>>(a: N, b: N) -> N {
    if a >= b {
        a - b
    } else {
        b - a
    }
}
39
40impl<'a> PiecewiseLinear<'a> {
41 pub fn calculate_for_fraction_times_denominator<N>(&self, n: N, d: N) -> N
43 where
44 N: AtLeast32BitUnsigned + Clone,
45 {
46 let n = n.min(d.clone());
47
48 if self.points.is_empty() {
49 return N::zero()
50 }
51
52 let next_point_index = self.points.iter().position(|p| n < p.0 * d.clone());
53
54 let (prev, next) = if let Some(next_point_index) = next_point_index {
55 if let Some(previous_point_index) = next_point_index.checked_sub(1) {
56 (self.points[previous_point_index], self.points[next_point_index])
57 } else {
58 return self.points.first().map(|p| p.1).unwrap_or_else(Perbill::zero) * d
60 }
61 } else {
62 return self.points.last().map(|p| p.1).unwrap_or_else(Perbill::zero) * d
64 };
65
66 let delta_y = multiply_by_rational_saturating(
67 abs_sub(n.clone(), prev.0 * d.clone()),
68 abs_sub(next.1.deconstruct(), prev.1.deconstruct()),
69 next.0.deconstruct().saturating_sub(prev.0.deconstruct()),
71 );
72
73 if (n > prev.0 * d.clone()) == (next.1.deconstruct() > prev.1.deconstruct()) {
75 (prev.1 * d).saturating_add(delta_y)
76 } else {
78 (prev.1 * d).saturating_sub(delta_y)
79 }
80 }
81}
82
83fn multiply_by_rational_saturating<N>(value: N, p: u32, q: u32) -> N
87where
88 N: AtLeast32BitUnsigned + Clone,
89{
90 let q = q.max(1);
91
92 let result_divisor_part = (value.clone() / q.into()).saturating_mul(p.into());
94
95 let result_remainder_part = {
96 let rem = value % q.into();
97
98 let rem_u32 = rem.saturated_into::<u32>();
100
101 let rem_part = rem_u32 as u64 * p as u64 / q as u64;
103
104 rem_part.saturated_into::<N>()
106 };
107
108 result_divisor_part.saturating_add(result_remainder_part)
110}
111
#[test]
fn test_multiply_by_rational_saturating() {
    let div = 100u32;
    // Maps a step in `0..=div` onto an evenly spaced point in `0..=max`.
    let scale = |step: u32, max: u128| -> u128 { step as u128 * max / div as u128 };

    for value_step in 0..=div {
        let value: u64 = scale(value_step, u64::MAX as u128).try_into().unwrap();
        for p_step in 0..=div {
            let p: u32 = scale(p_step, u32::MAX as u128).try_into().unwrap();
            for q_step in 1..=div {
                let q: u32 = scale(q_step, u32::MAX as u128).try_into().unwrap();
                // Exact reference in 128 bits, saturated back to u64.
                let expected: u64 =
                    (value as u128 * p as u128 / q as u128).try_into().unwrap_or(u64::MAX);
                assert_eq!(multiply_by_rational_saturating(value, p, q), expected);
            }
        }
    }
}
131
#[test]
fn test_calculate_for_fraction_times_denominator() {
    // Tent-shaped curve: 0.5 at x = 0, rising to 1.0 at x = 0.5, then falling to 0 at x = 1.
    let curve = PiecewiseLinear {
        points: &[
            (Perbill::from_parts(0), Perbill::from_parts(500_000_000)),
            (Perbill::from_parts(500_000_000), Perbill::from_parts(1_000_000_000)),
            (Perbill::from_parts(1_000_000_000), Perbill::from_parts(0)),
        ],
        maximum: Perbill::from_parts(1_000_000_000),
    };

    // Closed-form value of the curve above at `n / d`, scaled by `d`.
    fn reference(n: u64, d: u64) -> u64 {
        if n > Perbill::from_parts(500_000_000) * d {
            (2 * (d as u128) - 2 * (n as u128)).try_into().unwrap()
        } else {
            n + d / 2
        }
    }

    let div = 100u32;
    // Maps a step in `0..=div` onto an evenly spaced point in `0..=u64::MAX`.
    let scale = |step: u32| -> u64 {
        (step as u128 * u64::MAX as u128 / div as u128).try_into().unwrap()
    };

    for d_step in 0..=div {
        let d = scale(d_step);
        for n_step in 0..=d_step {
            let n = scale(n_step);
            let actual = curve.calculate_for_fraction_times_denominator(n, d);
            // Allow one unit of rounding error from the Perbill arithmetic.
            assert!(abs_sub(actual, reference(n, d)) <= 1);
        }
    }
}