use std::ops::{Add, Sub};
use cairo_lang_sierra::algorithm::topological_order::get_topological_ordering;
use cairo_lang_sierra::extensions::gas::{BuiltinCostsType, CostTokenType};
use cairo_lang_sierra::ids::ConcreteLibfuncId;
use cairo_lang_sierra::program::{BranchInfo, Invocation, Program, Statement, StatementIdx};
use cairo_lang_utils::casts::IntoOrPanic;
use cairo_lang_utils::iterators::zip_eq3;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::{Entry, UnorderedHashMap};
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use itertools::zip_eq;
use crate::CostError;
use crate::gas_info::GasInfo;
use crate::objects::{BranchCost, BranchCostSign, ConstCost, PreCost, WithdrawGasBranchInfo};
/// Values recorded by `analyze_gas_statements` for each `(statement, cost token type)` pair;
/// `compute_costs` returns them as `GasInfo::variable_values`.
type VariableValues = OrderedHashMap<(StatementIdx, CostTokenType), i64>;
/// A cost value type the wallet computations in this file are generic over.
///
/// Implemented below for `i32`, `ConstCost` and `PreCost`; see each impl for how components
/// that are missing or negative are handled.
pub trait CostTypeTrait:
std::fmt::Debug + Default + Clone + Eq + Add<Output = Self> + Sub<Output = Self>
{
/// Returns the minimum of the two values (per component for multi-component types).
fn min2(value1: &Self, value2: &Self) -> Self;
/// Returns the maximum of `values` (per component for multi-component types).
fn max(values: impl Iterator<Item = Self>) -> Self;
/// Returns `value` with negative components replaced by zero.
fn rectify(value: &Self) -> Self;
}
/// Scalar costs: plain integer min/max, with negative values clamped to zero by `rectify`.
impl CostTypeTrait for i32 {
    /// The smaller of the two values.
    fn min2(value1: &Self, value2: &Self) -> Self {
        if value2 < value1 { *value2 } else { *value1 }
    }

    /// The largest value, or `0` for an empty iterator.
    fn max(values: impl Iterator<Item = Self>) -> Self {
        values.reduce(std::cmp::max).unwrap_or_default()
    }

    /// `value` if non-negative, otherwise `0`.
    fn rectify(value: &Self) -> Self {
        if *value < 0 { 0 } else { *value }
    }
}
/// Component-wise cost operations for `ConstCost`.
impl CostTypeTrait for ConstCost {
    /// Takes the smaller value of each component independently.
    fn min2(value1: &Self, value2: &Self) -> Self {
        let ConstCost { steps: steps1, holes: holes1, range_checks: rc1, range_checks96: rc96_1 } =
            *value1;
        let ConstCost { steps: steps2, holes: holes2, range_checks: rc2, range_checks96: rc96_2 } =
            *value2;
        ConstCost {
            steps: steps1.min(steps2),
            holes: holes1.min(holes2),
            range_checks: rc1.min(rc2),
            range_checks96: rc96_1.min(rc96_2),
        }
    }

    /// Takes the largest value of each component independently.
    ///
    /// A single value is returned unchanged (even if negative); an empty iterator yields
    /// `ConstCost::default()`.
    fn max(values: impl Iterator<Item = Self>) -> Self {
        let mut largest: Option<Self> = None;
        for value in values {
            largest = Some(match largest {
                None => value,
                Some(acc) => ConstCost {
                    steps: acc.steps.max(value.steps),
                    holes: acc.holes.max(value.holes),
                    range_checks: acc.range_checks.max(value.range_checks),
                    range_checks96: acc.range_checks96.max(value.range_checks96),
                },
            });
        }
        largest.unwrap_or_default()
    }

    /// Clamps every negative component to zero.
    fn rectify(value: &Self) -> Self {
        ConstCost {
            steps: value.steps.max(0),
            holes: value.holes.max(0),
            range_checks: value.range_checks.max(0),
            range_checks96: value.range_checks96.max(0),
        }
    }
}
/// Per-token cost operations for `PreCost`.
impl CostTypeTrait for PreCost {
    /// Keeps only the tokens present in both values, each with the smaller amount.
    ///
    /// Tokens present in just one of the values are dropped. This matches a token-wise minimum
    /// with "missing = 0" only when the present amount is non-negative.
    fn min2(value1: &Self, value2: &Self) -> Self {
        let mut res = Self::default();
        for (token_type, lhs) in value1.0.iter() {
            if let Some(rhs) = value2.0.get(token_type) {
                res.0.insert(*token_type, std::cmp::min(*lhs, *rhs));
            }
        }
        res
    }

    /// Takes, for every token appearing in any value, the maximum amount, where the running
    /// maximum starts from `0`. As a result, negative amounts are raised to `0`.
    fn max(values: impl Iterator<Item = Self>) -> Self {
        values.fold(Self::default(), |mut acc, value| {
            for (token_type, amount) in value.0 {
                let current = acc.0.get(&token_type).copied().unwrap_or(0);
                acc.0.insert(token_type, std::cmp::max(current, amount));
            }
            acc
        })
    }

    /// Clamps every negative token amount to zero, keeping all tokens.
    fn rectify(value: &Self) -> Self {
        PreCost(
            value
                .0
                .iter()
                .map(|(token_type, amount): (&CostTokenType, &i32)| {
                    (*token_type, (*amount).max(0))
                })
                .collect(),
        )
    }
}
pub fn compute_costs<
CostType: CostTypeTrait,
SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
program: &Program,
get_cost_fn: &dyn Fn(&ConcreteLibfuncId) -> Vec<BranchCost>,
specific_cost_context: &SpecificCostContext,
enforced_wallet_values: &OrderedHashMap<StatementIdx, CostType>,
) -> Result<GasInfo, CostError> {
let mut context = CostContext {
program,
get_cost_fn,
enforced_wallet_values,
costs: Default::default(),
target_values: Default::default(),
};
context.prepare_wallet(specific_cost_context)?;
context.target_values = context.compute_target_values(specific_cost_context)?;
context.costs = Default::default();
context.prepare_wallet(specific_cost_context)?;
for (idx, value) in enforced_wallet_values.iter() {
if context.wallet_at_ex(idx, false).value != *value {
return Err(CostError::EnforceWalletValueFailed(*idx));
}
}
let mut variable_values = VariableValues::default();
for i in 0..program.statements.len() {
analyze_gas_statements(
&context,
specific_cost_context,
&StatementIdx(i),
&mut variable_values,
)?;
}
let function_costs = program
.funcs
.iter()
.map(|func| {
let res = SpecificCostContext::to_cost_map(context.wallet_at(&func.entry_point).value);
(func.id.clone(), res)
})
.collect();
Ok(GasInfo { variable_values, function_costs })
}
/// Returns the statements whose wallet values the branch requirements of the invocation at
/// `idx` read, in insertion order.
///
/// The result includes the target of every branch except a successful `WithdrawGas`, which
/// does not depend on its target's wallet value. It also includes the entry point of the
/// function referenced by any `FunctionCost` branch.
fn get_branch_requirements_dependencies(
    idx: &StatementIdx,
    invocation: &Invocation,
    libfunc_cost: &[BranchCost],
) -> OrderedHashSet<StatementIdx> {
    let mut dependencies: OrderedHashSet<StatementIdx> = Default::default();
    for (branch_info, branch_cost) in zip_eq(&invocation.branches, libfunc_cost) {
        if matches!(
            branch_cost,
            BranchCost::WithdrawGas(WithdrawGasBranchInfo { success: true, .. })
        ) {
            continue;
        }
        if let BranchCost::FunctionCost { function, .. } = branch_cost {
            dependencies.insert(function.entry_point);
        }
        dependencies.insert(idx.next(&branch_info.target));
    }
    dependencies
}
/// Returns the wallet value required by each branch of the invocation at `idx`, in branch
/// order, as computed by `specific_context`.
///
/// When `rectify` is set, each requirement is passed through `WalletInfo::rectify`, which
/// replaces negative components with zero.
fn get_branch_requirements<
    CostType: CostTypeTrait,
    SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
    specific_context: &SpecificCostContext,
    wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<CostType>,
    idx: &StatementIdx,
    invocation: &Invocation,
    libfunc_cost: &[BranchCost],
    rectify: bool,
) -> Vec<WalletInfo<CostType>> {
    let mut requirements = Vec::with_capacity(invocation.branches.len());
    for (branch_info, branch_cost) in zip_eq(&invocation.branches, libfunc_cost) {
        let requirement =
            specific_context.get_branch_requirement(wallet_at_fn, idx, branch_info, branch_cost);
        requirements.push(if rectify { requirement.rectify() } else { requirement });
    }
    requirements
}
/// Records into `variable_values` the variable values implied by the statement at `idx`.
///
/// A non-invocation statement records nothing. For each branch of an invocation:
/// * a successful `WithdrawGas` splits the planned withdrawal. Its positive part is stored at
///   `idx`, and its negated positive part is stored at the branch target.
/// * a `RedepositGas`, or a `FunctionCost` with `BranchCostSign::Add`, stores the wallet value
///   minus the branch requirement at `idx`.
/// * any other branch of a multi-branch invocation stores that same difference at the branch
///   target.
///
/// # Errors
///
/// Propagates errors from `get_gas_withdrawal`.
///
/// # Panics
///
/// Panics if a value is recorded twice for the same `(statement, token)` key.
fn analyze_gas_statements<
    CostType: CostTypeTrait,
    SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
    context: &CostContext<'_, CostType>,
    specific_context: &SpecificCostContext,
    idx: &StatementIdx,
    variable_values: &mut VariableValues,
) -> Result<(), CostError> {
    let invocation = match context.program.get_statement(idx).unwrap() {
        Statement::Invocation(invocation) => invocation,
        Statement::Return(_) => return Ok(()),
    };
    let libfunc_cost: Vec<BranchCost> = context.get_cost(&invocation.libfunc_id);
    let branch_requirements: Vec<WalletInfo<CostType>> = get_branch_requirements(
        specific_context,
        &|statement_idx| context.wallet_at(statement_idx),
        idx,
        invocation,
        &libfunc_cost,
        false,
    );
    let wallet_value = context.wallet_at(idx).value;
    let is_multi_branch = invocation.branches.len() > 1;

    for (branch_info, branch_cost, branch_requirement) in
        zip_eq3(&invocation.branches, &libfunc_cost, &branch_requirements)
    {
        let target_idx = idx.next(&branch_info.target);
        let future_wallet_value = context.wallet_at(&target_idx).value;
        match branch_cost {
            BranchCost::WithdrawGas(WithdrawGasBranchInfo { success: true, .. }) => {
                let withdrawal = specific_context.get_gas_withdrawal(
                    idx,
                    branch_cost,
                    &wallet_value,
                    future_wallet_value,
                )?;
                for (token_type, amount) in SpecificCostContext::to_full_cost_map(withdrawal) {
                    let previous =
                        variable_values.insert((*idx, token_type), std::cmp::max(amount, 0));
                    assert_eq!(previous, None);
                    let previous = variable_values
                        .insert((target_idx, token_type), std::cmp::max(-amount, 0));
                    assert_eq!(previous, None);
                }
            }
            BranchCost::RedepositGas
            | BranchCost::FunctionCost { sign: BranchCostSign::Add, .. } => {
                let cost = wallet_value.clone() - branch_requirement.value.clone();
                for (token_type, amount) in SpecificCostContext::to_full_cost_map(cost) {
                    assert_eq!(variable_values.insert((*idx, token_type), amount), None);
                }
            }
            _ if is_multi_branch => {
                let cost = wallet_value.clone() - branch_requirement.value.clone();
                for (token_type, amount) in SpecificCostContext::to_full_cost_map(cost) {
                    assert_eq!(variable_values.insert((target_idx, token_type), amount), None);
                }
            }
            _ => {}
        }
    }
    Ok(())
}
/// Cost-type-specific behavior plugged into the generic wallet computation.
///
/// Implemented in this file by `PreCostContext` (for `PreCost`) and by `PostcostContext`
/// (for any `PostCostTypeEx`).
pub trait SpecificCostContextTrait<CostType: CostTypeTrait> {
/// Converts `cost` into the per-token map reported as a function cost by `compute_costs`.
fn to_cost_map(cost: CostType) -> OrderedHashMap<CostTokenType, i64>;
/// Converts `cost` into the per-token map used when recording variable values in
/// `analyze_gas_statements`.
fn to_full_cost_map(cost: CostType) -> OrderedHashMap<CostTokenType, i64>;
/// Returns the withdrawal planned for the successful branch of the withdraw-gas statement at
/// `idx`. `wallet_value` is the wallet value at `idx`, and `future_wallet_value` is the one at
/// the branch target.
fn get_gas_withdrawal(
&self,
idx: &StatementIdx,
branch_cost: &BranchCost,
wallet_value: &CostType,
future_wallet_value: CostType,
) -> Result<CostType, CostError>;
/// Returns the wallet value required to take `branch_info` (with cost `branch_cost`) from
/// the statement at `idx`; `wallet_at_fn` gives the wallet value of other statements.
fn get_branch_requirement(
&self,
wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<CostType>,
idx: &StatementIdx,
branch_info: &BranchInfo,
branch_cost: &BranchCost,
) -> WalletInfo<CostType>;
}
/// The wallet value computed for (or required at) a statement.
#[derive(Clone, Debug, Default)]
pub struct WalletInfo<CostType: CostTypeTrait> {
/// The wallet value itself.
value: CostType,
}
impl<CostType: CostTypeTrait> WalletInfo<CostType> {
    /// Merges the requirements of a statement's branches into a single wallet value.
    ///
    /// The result is the maximum of the branch values. If the statement has more than one
    /// branch, or its only branch cost is `RedepositGas`, the result is also raised to
    /// `target_value` when one is given. Presumably the per-branch alignment (or the redeposit)
    /// absorbs the difference; see `CostContext::handle_excess_at`.
    fn merge(
        branch_costs: &[BranchCost],
        branches: Vec<Self>,
        target_value: Option<&CostType>,
    ) -> Self {
        let may_raise_to_target =
            branches.len() > 1 || matches!(branch_costs, [BranchCost::RedepositGas]);
        let branches_max = CostType::max(branches.into_iter().map(|branch| branch.value));
        let value = match target_value {
            Some(target) if may_raise_to_target => {
                CostType::max([branches_max, target.clone()].into_iter())
            }
            _ => branches_max,
        };
        WalletInfo { value }
    }

    /// Returns a copy with negative components of the value replaced by zero.
    fn rectify(&self) -> Self {
        let value = CostType::rectify(&self.value);
        Self { value }
    }
}
/// Wraps a bare cost value as a wallet value.
impl<CostType: CostTypeTrait> From<CostType> for WalletInfo<CostType> {
    fn from(value: CostType) -> Self {
        Self { value }
    }
}

/// Adds the underlying cost values of two wallet values.
impl<CostType: CostTypeTrait> std::ops::Add for WalletInfo<CostType> {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        let value = self.value + other.value;
        Self { value }
    }
}
/// State shared by the wallet and target-value computations performed by `compute_costs`.
struct CostContext<'a, CostType: CostTypeTrait> {
/// The program being analyzed.
program: &'a Program,
/// Returns the cost of each branch of a libfunc.
get_cost_fn: &'a dyn Fn(&ConcreteLibfuncId) -> Vec<BranchCost>,
/// Wallet values that must hold at specific statements; `compute_costs` fails with
/// `CostError::EnforceWalletValueFailed` when a computed value differs.
enforced_wallet_values: &'a OrderedHashMap<StatementIdx, CostType>,
/// Cache of computed wallet values, filled by `prepare_wallet`.
costs: UnorderedHashMap<StatementIdx, WalletInfo<CostType>>,
/// Target wallet values per statement, consulted by `WalletInfo::merge` when recomputing.
target_values: UnorderedHashMap<StatementIdx, CostType>,
}
impl<CostType: CostTypeTrait> CostContext<'_, CostType> {
    /// Returns the cost of each branch of the libfunc `libfunc_id`, as given by `get_cost_fn`.
    fn get_cost(&self, libfunc_id: &ConcreteLibfuncId) -> Vec<BranchCost> {
        (self.get_cost_fn)(libfunc_id)
    }

    /// Returns the wallet value at `idx`, preferring an enforced value when one exists.
    ///
    /// Equivalent to `wallet_at_ex(idx, true)`.
    fn wallet_at(&self, idx: &StatementIdx) -> WalletInfo<CostType> {
        self.wallet_at_ex(idx, true)
    }

    /// Returns the wallet value at `idx`.
    ///
    /// If `with_enforced_values` is set and `idx` has an enforced wallet value, that value is
    /// returned. Otherwise the value cached in `costs` is returned.
    ///
    /// # Panics
    ///
    /// Panics if the cached value is needed and has not been computed yet.
    fn wallet_at_ex(&self, idx: &StatementIdx, with_enforced_values: bool) -> WalletInfo<CostType> {
        if with_enforced_values {
            if let Some(enforced_wallet_value) = self.enforced_wallet_values.get(idx) {
                return WalletInfo::from(enforced_wallet_value.clone());
            }
        }
        self.costs
            .get(idx)
            .unwrap_or_else(|| panic!("Wallet value for statement {idx} was not yet computed."))
            .clone()
    }

    /// Computes the wallet value of every statement and stores it in `costs`.
    ///
    /// Statements are visited in a topological order of the dependencies given by
    /// `get_branch_requirements_dependencies`. That order is expected to list a statement's
    /// dependencies before the statement itself; otherwise `wallet_at_ex` panics on the missing
    /// cache entry.
    ///
    /// # Errors
    ///
    /// Propagates errors from `compute_topological_order`. Cycle detection is enabled here, so
    /// a dependency cycle is reported as `CostError::UnexpectedCycle`.
    fn prepare_wallet<SpecificCostContext: SpecificCostContextTrait<CostType>>(
        &mut self,
        specific_cost_context: &SpecificCostContext,
    ) -> Result<(), CostError> {
        let topological_order =
            compute_topological_order(self.program.statements.len(), true, |current_idx| {
                match &self.program.get_statement(current_idx).unwrap() {
                    Statement::Return(_) => {
                        vec![]
                    }
                    Statement::Invocation(invocation) => {
                        let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
                        get_branch_requirements_dependencies(current_idx, invocation, &libfunc_cost)
                            .into_iter()
                            .collect()
                    }
                }
            })?;

        for current_idx in topological_order {
            let res = self.no_cache_compute_wallet_at(&current_idx, specific_cost_context);
            self.costs.insert(current_idx, res);
        }
        Ok(())
    }

    /// Computes the wallet value at `idx` from the current `costs` cache, without caching it.
    ///
    /// A `return` statement gets the default value. An invocation gets the merge (see
    /// `WalletInfo::merge`) of its rectified branch requirements, taking the statement's
    /// target value into account where `merge` allows.
    fn no_cache_compute_wallet_at<SpecificCostContext: SpecificCostContextTrait<CostType>>(
        &self,
        idx: &StatementIdx,
        specific_cost_context: &SpecificCostContext,
    ) -> WalletInfo<CostType> {
        match &self.program.get_statement(idx).unwrap() {
            Statement::Return(_) => Default::default(),
            Statement::Invocation(invocation) => {
                let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
                let branch_requirements: Vec<WalletInfo<CostType>> = get_branch_requirements(
                    specific_cost_context,
                    &|statement_idx| self.wallet_at(statement_idx),
                    idx,
                    invocation,
                    &libfunc_cost,
                    true,
                );
                WalletInfo::merge(&libfunc_cost, branch_requirements, self.target_values.get(idx))
            }
        }
    }

    /// Computes the target wallet value of every statement.
    ///
    /// `handle_excess_at` passes excess from each statement to its branch targets. Statements
    /// are visited in reverse of a topological order of the successor graph, with cycle
    /// detection disabled. The target value of a statement is its computed wallet value
    /// (ignoring enforced values) plus its excess, or plus the default value if it received
    /// no excess.
    ///
    /// # Errors
    ///
    /// Propagates errors from `compute_topological_order` and `handle_excess_at`.
    fn compute_target_values<SpecificCostContext: SpecificCostContextTrait<CostType>>(
        &self,
        specific_cost_context: &SpecificCostContext,
    ) -> Result<UnorderedHashMap<StatementIdx, CostType>, CostError> {
        let topological_order =
            compute_topological_order(self.program.statements.len(), false, |current_idx| {
                match self.program.get_statement(current_idx).unwrap() {
                    Statement::Return(_) => {
                        vec![]
                    }
                    Statement::Invocation(invocation) => invocation
                        .branches
                        .iter()
                        .map(|branch_info| current_idx.next(&branch_info.target))
                        .collect(),
                }
            })?;

        let mut excess = UnorderedHashMap::<StatementIdx, CostType>::default();
        let mut finalized_excess_statements = UnorderedHashSet::<StatementIdx>::default();
        for idx in topological_order.iter().rev() {
            self.handle_excess_at(
                idx,
                specific_cost_context,
                &mut excess,
                &mut finalized_excess_statements,
            )?;
        }

        Ok((0..self.program.statements.len())
            .map(|i| {
                let idx = StatementIdx(i);
                let original_wallet_value = self.wallet_at_ex(&idx, false).value;
                (idx, original_wallet_value + excess.get(&idx).cloned().unwrap_or_default())
            })
            .collect())
    }

    /// Finalizes the excess at `idx` and passes it on to the targets of its branches.
    ///
    /// If `idx` has an enforced wallet value, its excess is set to the rectified difference
    /// between the enforced and the computed wallet value. A `return` statement drops its
    /// excess. For an invocation, the excess passed to each branch target is:
    /// * for a multi-branch invocation: on a successful `WithdrawGas` branch, the current
    ///   excess minus the planned withdrawal, rectified. On other branches, the current excess
    ///   plus the rectified difference between the wallet value and the branch requirement.
    /// * otherwise, for `RedepositGas`: the default value.
    /// * otherwise, for a `FunctionCost` with `BranchCostSign::Add`: the current excess plus the
    ///   rectified difference between the wallet value and the branch requirement.
    /// * otherwise: the current excess unchanged.
    ///
    /// A target that already holds excess keeps the `CostType::min2` of the old and new value.
    ///
    /// # Errors
    ///
    /// Propagates errors from `get_gas_withdrawal`.
    fn handle_excess_at<SpecificCostContext: SpecificCostContextTrait<CostType>>(
        &self,
        idx: &StatementIdx,
        specific_cost_context: &SpecificCostContext,
        excess: &mut UnorderedHashMap<StatementIdx, CostType>,
        finalized_excess_statements: &mut UnorderedHashSet<StatementIdx>,
    ) -> Result<(), CostError> {
        let wallet_value = self.wallet_at_ex(idx, false).value;

        if let Some(enforced_wallet_value) = self.enforced_wallet_values.get(idx) {
            excess.insert(
                *idx,
                CostType::rectify(&(enforced_wallet_value.clone() - wallet_value.clone())),
            );
        }

        finalized_excess_statements.insert(*idx);

        let current_excess = excess.get(idx).cloned().unwrap_or_default();

        let invocation = match &self.program.get_statement(idx).unwrap() {
            Statement::Invocation(invocation) => invocation,
            Statement::Return(_) => {
                // The excess has nowhere to go; drop it.
                return Ok(());
            }
        };

        let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);

        let branch_requirements = get_branch_requirements(
            specific_cost_context,
            &|statement_idx| self.wallet_at(statement_idx),
            idx,
            invocation,
            &libfunc_cost,
            false,
        );

        for (branch_info, branch_cost, branch_requirement) in
            zip_eq3(&invocation.branches, &libfunc_cost, branch_requirements)
        {
            let branch_statement = idx.next(&branch_info.target);
            if finalized_excess_statements.contains(&branch_statement) {
                // Don't update statements which were already finalized.
                // NOTE(review): this returns rather than `continue`s, so the remaining branches
                // of this invocation receive no excess either. Kept as-is because changing it
                // changes the computed gas values; confirm this is intended.
                return Ok(());
            }

            let future_wallet_value = self.wallet_at(&branch_statement).value;
            let mut actual_excess = current_excess.clone();
            if invocation.branches.len() > 1 {
                if let BranchCost::WithdrawGas(WithdrawGasBranchInfo { success: true, .. }) =
                    branch_cost
                {
                    let planned_withdrawal = specific_cost_context.get_gas_withdrawal(
                        idx,
                        branch_cost,
                        &wallet_value,
                        future_wallet_value,
                    )?;
                    actual_excess = CostType::rectify(&(actual_excess - planned_withdrawal));
                } else {
                    let additional_excess = wallet_value.clone() - branch_requirement.value;
                    actual_excess = actual_excess + CostType::rectify(&additional_excess);
                }
            } else if let BranchCost::RedepositGas = branch_cost {
                actual_excess = Default::default();
            } else if let BranchCost::FunctionCost { sign: BranchCostSign::Add, .. } = branch_cost {
                let additional_excess = wallet_value.clone() - branch_requirement.value;
                actual_excess = actual_excess + CostType::rectify(&additional_excess);
            }

            match excess.entry(branch_statement) {
                Entry::Occupied(mut entry) => {
                    let current_value = entry.get();
                    entry.insert(CostType::min2(current_value, &actual_excess));
                }
                Entry::Vacant(entry) => {
                    entry.insert(actual_excess);
                }
            }
        }
        Ok(())
    }
}
/// Orders all `n_statements` statements topologically, using `dependencies_callback` to list
/// the statements each statement depends on.
///
/// # Errors
///
/// Returns `CostError::StatementOutOfBounds` for an out-of-range dependency. When
/// `detect_cycles` is set, returns `CostError::UnexpectedCycle` for a cycle.
fn compute_topological_order(
    n_statements: usize,
    detect_cycles: bool,
    dependencies_callback: impl Fn(&StatementIdx) -> Vec<StatementIdx>,
) -> Result<Vec<StatementIdx>, CostError> {
    let all_statements = (0..n_statements).map(StatementIdx);
    get_topological_ordering(
        detect_cycles,
        all_statements,
        n_statements,
        |statement_idx| Ok(dependencies_callback(&statement_idx)),
        CostError::StatementOutOfBounds,
        |_cycle_member| CostError::UnexpectedCycle,
    )
}
/// Cost context for computing `PreCost` wallet values (see its
/// `SpecificCostContextTrait<PreCost>` impl).
pub struct PreCostContext {}
impl SpecificCostContextTrait<PreCost> for PreCostContext {
fn to_cost_map(cost: PreCost) -> OrderedHashMap<CostTokenType, i64> {
let res = cost.0;
res.into_iter().map(|(token_type, val)| (token_type, val as i64)).collect()
}
fn to_full_cost_map(cost: PreCost) -> OrderedHashMap<CostTokenType, i64> {
CostTokenType::iter_precost()
.map(|token_type| (*token_type, (*cost.0.get(token_type).unwrap_or(&0)).into()))
.collect()
}
fn get_gas_withdrawal(
&self,
_idx: &StatementIdx,
_branch_cost: &BranchCost,
wallet_value: &PreCost,
future_wallet_value: PreCost,
) -> Result<PreCost, CostError> {
Ok(future_wallet_value - wallet_value.clone())
}
fn get_branch_requirement(
&self,
wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<PreCost>,
idx: &StatementIdx,
branch_info: &BranchInfo,
branch_cost: &BranchCost,
) -> WalletInfo<PreCost> {
let branch_cost = match branch_cost {
BranchCost::Regular { const_cost: _, pre_cost } => pre_cost.clone(),
BranchCost::BranchAlign | BranchCost::RedepositGas => Default::default(),
BranchCost::FunctionCost { const_cost: _, function, sign } => {
let func_cost = wallet_at_fn(&function.entry_point).value;
match sign {
BranchCostSign::Add => PreCost::default() - func_cost,
BranchCostSign::Subtract => func_cost,
}
}
BranchCost::WithdrawGas(info) => {
if info.success {
return Default::default();
} else {
Default::default()
}
}
};
let future_wallet_value = wallet_at_fn(&idx.next(&branch_info.target));
WalletInfo::from(branch_cost) + future_wallet_value
}
}
/// Extra operations required from a cost type used with `PostcostContext`.
pub trait PostCostTypeEx: CostTypeTrait + Copy {
/// Converts a `ConstCost` into this cost type.
fn from_const_cost(const_cost: &ConstCost) -> Self;
/// Converts the cost into a per-token map; `PostcostContext` uses it as its
/// `SpecificCostContextTrait::to_full_cost_map`.
fn to_full_cost_map(self) -> OrderedHashMap<CostTokenType, i64>;
}
/// Scalar post-costs: a `ConstCost` collapses to its `cost()`, reported as `CostTokenType::Const`.
impl PostCostTypeEx for i32 {
    fn from_const_cost(const_cost: &ConstCost) -> Self {
        const_cost.cost()
    }

    /// A single-entry map holding the scalar under `CostTokenType::Const`.
    fn to_full_cost_map(self) -> OrderedHashMap<CostTokenType, i64> {
        std::iter::once((CostTokenType::Const, i64::from(self))).collect()
    }
}
/// Component-wise post-costs: each `ConstCost` field is reported under its own token type.
impl PostCostTypeEx for ConstCost {
    fn from_const_cost(const_cost: &ConstCost) -> Self {
        *const_cost
    }

    /// Returns an entry for every component of the cost.
    ///
    /// `range_checks96` is included so that a non-zero range-check-96 component is not
    /// dropped from the recorded variable values or the reported function costs. It was
    /// previously omitted, even though `CostTypeTrait for ConstCost` tracks it.
    fn to_full_cost_map(self) -> OrderedHashMap<CostTokenType, i64> {
        [
            (CostTokenType::Step, self.steps.into()),
            (CostTokenType::Hole, self.holes.into()),
            (CostTokenType::RangeCheck, self.range_checks.into()),
            (CostTokenType::RangeCheck96, self.range_checks96.into()),
        ]
        .into_iter()
        .collect()
    }
}
/// Cost context for the post-cost computation, generic over any `PostCostTypeEx` (implemented
/// in this file for `i32` and `ConstCost`).
pub struct PostcostContext<'a> {
/// Returns the ap change at a statement; used to cost `BranchCost::BranchAlign`.
pub get_ap_change_fn: &'a dyn Fn(&StatementIdx) -> usize,
/// Gas info whose `variable_values` are read to cost withdraw-gas and redeposit-gas
/// statements. Presumably the result of a pre-cost `compute_costs` run (TODO confirm).
pub precost_gas_info: &'a GasInfo,
}
impl<CostType: PostCostTypeEx> SpecificCostContextTrait<CostType> for PostcostContext<'_> {
    /// Returns an empty map for the default cost, and the full cost map otherwise.
    fn to_cost_map(cost: CostType) -> OrderedHashMap<CostTokenType, i64> {
        if cost == CostType::default() { Default::default() } else { Self::to_full_cost_map(cost) }
    }

    /// Delegates to `PostCostTypeEx::to_full_cost_map`.
    fn to_full_cost_map(cost: CostType) -> OrderedHashMap<CostTokenType, i64> {
        cost.to_full_cost_map()
    }

    /// Returns the wallet value at the branch target plus the withdraw-gas cost, minus the
    /// wallet value at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `branch_cost` is not a successful `BranchCost::WithdrawGas`.
    fn get_gas_withdrawal(
        &self,
        idx: &StatementIdx,
        branch_cost: &BranchCost,
        wallet_value: &CostType,
        future_wallet_value: CostType,
    ) -> Result<CostType, CostError> {
        let BranchCost::WithdrawGas(info) = branch_cost else {
            panic!("Unexpected BranchCost: {branch_cost:?}.");
        };
        assert!(info.success, "Unexpected BranchCost: Expected `success == true`, got {info:?}.");

        let withdraw_gas_cost =
            CostType::from_const_cost(&self.compute_withdraw_gas_cost(idx, info));
        Ok(future_wallet_value + withdraw_gas_cost - *wallet_value)
    }

    /// Returns the branch's own cost plus the wallet value at its target.
    ///
    /// A successful `WithdrawGas` requires only the withdraw-gas cost. A `FunctionCost`
    /// contributes the callee's entry wallet value plus its const cost, negated for
    /// `BranchCostSign::Add`.
    ///
    /// # Panics
    ///
    /// Panics if the ap change of a `BranchAlign` does not fit in `i32`.
    fn get_branch_requirement(
        &self,
        wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<CostType>,
        idx: &StatementIdx,
        branch_info: &BranchInfo,
        branch_cost: &BranchCost,
    ) -> WalletInfo<CostType> {
        let branch_cost_val = match branch_cost {
            BranchCost::Regular { const_cost, pre_cost: _ } => {
                CostType::from_const_cost(const_cost)
            }
            BranchCost::BranchAlign => {
                let ap_change = (self.get_ap_change_fn)(idx);
                let res = if ap_change == 0 {
                    ConstCost::default()
                } else {
                    // A non-zero ap change costs one step plus one hole per unit of ap change.
                    // Checked conversion: a silent `as` truncation would under-charge.
                    ConstCost {
                        steps: 1,
                        holes: ap_change.into_or_panic(),
                        range_checks: 0,
                        range_checks96: 0,
                    }
                };
                CostType::from_const_cost(&res)
            }
            BranchCost::FunctionCost { const_cost, function, sign } => {
                let cost = wallet_at_fn(&function.entry_point).value
                    + CostType::from_const_cost(const_cost);
                match sign {
                    BranchCostSign::Add => CostType::default() - cost,
                    BranchCostSign::Subtract => cost,
                }
            }
            BranchCost::WithdrawGas(info) => {
                let cost = CostType::from_const_cost(&self.compute_withdraw_gas_cost(idx, info));
                if info.success {
                    // Independent of the wallet value at the branch target.
                    return WalletInfo::from(cost);
                }
                cost
            }
            BranchCost::RedepositGas => {
                CostType::from_const_cost(&self.compute_redeposit_gas_cost(idx))
            }
        };
        let future_wallet_value = wallet_at_fn(&idx.next(&branch_info.target));
        WalletInfo { value: branch_cost_val } + future_wallet_value
    }
}
impl PostcostContext<'_> {
    /// Returns the const cost of the withdraw-gas statement at `idx`, computed by
    /// `WithdrawGasBranchInfo::const_cost` from the precost variable values recorded at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if a required variable value is missing or does not fit the requested type.
    fn compute_withdraw_gas_cost(
        &self,
        idx: &StatementIdx,
        info: &WithdrawGasBranchInfo,
    ) -> ConstCost {
        let precost_values = &self.precost_gas_info.variable_values;
        info.const_cost(|token_type| precost_values[&(*idx, token_type)].into_or_panic())
    }

    /// Returns the step cost of the redeposit-gas statement at `idx`, computed by
    /// `BuiltinCostsType::cost_computation_steps` from the precost variable values at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if a required variable value is missing or a conversion does not fit.
    fn compute_redeposit_gas_cost(&self, idx: &StatementIdx) -> ConstCost {
        let precost_values = &self.precost_gas_info.variable_values;
        let steps = BuiltinCostsType::cost_computation_steps(false, |token_type| {
            precost_values[&(*idx, token_type)].into_or_panic()
        });
        ConstCost::steps(steps.into_or_panic())
    }
}