soroban_env_host/budget/
model.rs1use crate::{
2 xdr::{ContractCostParamEntry, ScErrorCode, ScErrorType},
3 HostError,
4};
5use core::fmt::{Debug, Display};
6
7pub trait HostCostModel {
31 fn evaluate(&self, iterations: u64, input: Option<u64>) -> Result<u64, HostError>;
32
33 #[cfg(any(test, feature = "testutils", feature = "bench"))]
34 fn reset(&mut self);
35}
36
/// Number of fractional bits in a [`ScaledU64`].
///
/// A scaled value stores `real_value * 2^7` (i.e. `* 128`), so a linear cost
/// coefficient can express fractions down to 1/128.
const COST_MODEL_LIN_TERM_SCALE_BITS: u32 = 7;

/// Unsigned fixed-point number with [`COST_MODEL_LIN_TERM_SCALE_BITS`]
/// fractional bits, used for the linear term of a cost model.
///
/// The wrapped `u64` is the already-scaled representation. Use
/// [`ScaledU64::from_unscaled_u64`] to build one from a whole number and
/// [`ScaledU64::unscale`] to truncate back to whole units.
#[derive(Clone, Copy, Default, Debug)]
pub struct ScaledU64(pub(crate) u64);

impl ScaledU64 {
    /// Largest whole-unit value whose scaled form still fits in a `u64`.
    const MAX_UNSCALED: u64 = u64::MAX >> COST_MODEL_LIN_TERM_SCALE_BITS;

    /// Converts a whole-unit value into its scaled representation.
    ///
    /// Saturates to `u64::MAX` when `u` is too large to scale. A plain left
    /// shift would silently drop the high bits, e.g. `1 << 57` would scale
    /// to `0`.
    pub const fn from_unscaled_u64(u: u64) -> Self {
        if u > Self::MAX_UNSCALED {
            ScaledU64(u64::MAX)
        } else {
            ScaledU64(u << COST_MODEL_LIN_TERM_SCALE_BITS)
        }
    }

    /// Converts back to whole units, discarding (rounding down) the
    /// fractional bits.
    pub const fn unscale(self) -> u64 {
        self.0 >> COST_MODEL_LIN_TERM_SCALE_BITS
    }

    /// Returns `true` if the scaled value is exactly zero.
    pub const fn is_zero(&self) -> bool {
        self.0 == 0
    }

    /// Multiplies by a whole-unit factor, saturating at `u64::MAX`.
    /// The result stays in scaled form.
    pub const fn saturating_mul(&self, rhs: u64) -> Self {
        ScaledU64(self.0.saturating_mul(rhs))
    }

    /// Divides by a whole-unit divisor, truncating toward zero.
    /// Division by zero yields `ScaledU64(0)` instead of panicking.
    pub const fn safe_div(&self, rhs: u64) -> Self {
        ScaledU64(match self.0.checked_div(rhs) {
            Some(v) => v,
            None => 0,
        })
    }
}
71
72impl Display for ScaledU64 {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 write!(f, "{}", self.0)
75 }
76}
77
/// Benchmark-only conversion from a fractional (unscaled) coefficient.
///
/// The value is multiplied by `2^COST_MODEL_LIN_TERM_SCALE_BITS` and rounded
/// *up*. The final `as u64` cast is saturating: negative values and NaN
/// become `0`.
#[cfg(feature = "bench")]
impl From<f64> for ScaledU64 {
    fn from(unscaled: f64) -> Self {
        let scale_factor = (1 << COST_MODEL_LIN_TERM_SCALE_BITS) as f64;
        let rounded_up = (unscaled * scale_factor).ceil();
        ScaledU64(rounded_up as u64)
    }
}
86
87#[derive(Clone, Copy, Debug, Default)]
88pub struct MeteredCostComponent {
89 pub const_term: u64,
90 pub lin_term: ScaledU64,
91}
92
93impl TryFrom<&ContractCostParamEntry> for MeteredCostComponent {
94 type Error = HostError;
95
96 fn try_from(entry: &ContractCostParamEntry) -> Result<Self, Self::Error> {
97 if entry.const_term < 0 || entry.linear_term < 0 {
98 return Err((ScErrorType::Context, ScErrorCode::InvalidInput).into());
99 }
100 Ok(MeteredCostComponent {
101 const_term: entry.const_term as u64,
102 lin_term: ScaledU64(entry.linear_term as u64),
103 })
104 }
105}
106
107impl TryFrom<ContractCostParamEntry> for MeteredCostComponent {
108 type Error = HostError;
109
110 fn try_from(entry: ContractCostParamEntry) -> Result<Self, Self::Error> {
111 Self::try_from(&entry)
112 }
113}
114
115impl HostCostModel for MeteredCostComponent {
116 fn evaluate(&self, iterations: u64, input: Option<u64>) -> Result<u64, HostError> {
117 let const_term = self.const_term.saturating_mul(iterations);
118 match input {
119 Some(input) => {
120 let mut res = const_term;
121 if !self.lin_term.is_zero() {
122 let lin_cost = self
123 .lin_term
124 .saturating_mul(input)
125 .saturating_mul(iterations);
126 res = res.saturating_add(lin_cost.unscale())
127 }
128 Ok(res)
129 }
130 None => Ok(const_term),
131 }
132 }
133
134 #[cfg(any(test, feature = "testutils", feature = "bench"))]
135 fn reset(&mut self) {
136 self.const_term = 0;
137 self.lin_term = ScaledU64(0);
138 }
139}
140
#[cfg(test)]
mod test {
    use super::{HostCostModel, MeteredCostComponent, ScaledU64};

    /// A `lin_term` of `ScaledU64(5)` is 5/128 cost units per unit of input.
    /// These cases pin that the linear part is truncated only after it has
    /// been multiplied by both `input` and `iterations`.
    #[test]
    fn test_model_evaluation_with_rounding() {
        let test_model = MeteredCostComponent {
            const_term: 3,
            lin_term: ScaledU64(5),
        };
        // 3 + floor(5 * 1 * 1 / 128) = 3 + 0
        assert_eq!(3, test_model.evaluate(1, Some(1)).unwrap());
        // 3 + floor(5 * 26 * 1 / 128) = 3 + floor(130 / 128) = 4
        assert_eq!(4, test_model.evaluate(1, Some(26)).unwrap());
        // 3 * 26 + floor(5 * 1 * 26 / 128) = 78 + 1 = 79
        assert_eq!(79, test_model.evaluate(26, Some(1)).unwrap());
    }

    #[test]
    fn test_model_evaluation_without_input_uses_const_term_only() {
        let test_model = MeteredCostComponent {
            const_term: 3,
            lin_term: ScaledU64(5),
        };
        assert_eq!(30, test_model.evaluate(10, None).unwrap());
    }

    #[test]
    fn test_model_evaluation_saturates() {
        let test_model = MeteredCostComponent {
            const_term: u64::MAX,
            lin_term: ScaledU64(5),
        };
        assert_eq!(u64::MAX, test_model.evaluate(2, Some(2)).unwrap());
    }

    #[test]
    fn test_model_reset_zeroes_both_terms() {
        let mut test_model = MeteredCostComponent {
            const_term: 3,
            lin_term: ScaledU64(5),
        };
        test_model.reset();
        assert_eq!(0, test_model.evaluate(9, Some(9)).unwrap());
    }
}