use std::collections::hash_map;
use std::ops::{Add, Sub};
use cairo_lang_sierra::algorithm::topological_order::get_topological_ordering;
use cairo_lang_sierra::extensions::gas::{BuiltinCostWithdrawGasLibfunc, CostTokenType};
use cairo_lang_sierra::ids::ConcreteLibfuncId;
use cairo_lang_sierra::program::{BranchInfo, Invocation, Program, Statement, StatementIdx};
use cairo_lang_utils::casts::IntoOrPanic;
use cairo_lang_utils::iterators::zip_eq3;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use itertools::zip_eq;
use crate::gas_info::GasInfo;
use crate::objects::{BranchCost, ConstCost, PreCost};
use crate::CostError;
/// Maps a (statement, cost token type) pair to the `i64` amount recorded for it by
/// `analyze_gas_statements`.
type VariableValues = OrderedHashMap<(StatementIdx, CostTokenType), i64>;
/// Operations required from a type that represents a wallet value / cost.
///
/// Implemented in this file for `i32` and for `PreCost`.
pub trait CostTypeTrait:
std::fmt::Debug + Default + Clone + Eq + Add<Output = Self> + Sub<Output = Self>
{
/// Returns the minimum of the two values.
///
/// For `PreCost` the minimum is taken per token type, and only token types present in both
/// values are kept.
fn min2(value1: &Self, value2: &Self) -> Self;
/// Returns the maximum of all the given values.
///
/// Both visible implementations return the default value for an empty iterator.
fn max(values: impl Iterator<Item = Self>) -> Self;
/// Replaces negative amounts with zero, leaving non-negative amounts unchanged.
fn rectify(value: &Self) -> Self;
}
/// Cost arithmetic for a single scalar `i32` cost.
impl CostTypeTrait for i32 {
    /// Returns the smaller of the two values.
    fn min2(value1: &Self, value2: &Self) -> Self {
        if *value2 < *value1 { *value2 } else { *value1 }
    }

    /// Returns the largest yielded value, or `0` when `values` is empty.
    fn max(values: impl Iterator<Item = Self>) -> Self {
        values.fold(None, |best: Option<i32>, candidate| match best {
            Some(current) if current > candidate => Some(current),
            _ => Some(candidate),
        })
        .unwrap_or(0)
    }

    /// Clamps negative values to `0`.
    fn rectify(value: &Self) -> Self {
        if *value < 0 { 0 } else { *value }
    }
}
/// Cost arithmetic for `PreCost`, applied independently per token type.
impl CostTypeTrait for PreCost {
    /// Per-token minimum; token types that are missing from either side are dropped.
    fn min2(value1: &Self, value2: &Self) -> Self {
        PreCost(
            value1
                .0
                .iter()
                .filter_map(|(token_type, val1)| {
                    value2
                        .0
                        .get(token_type)
                        .map(|val2| (*token_type, std::cmp::min(*val1, *val2)))
                })
                .collect(),
        )
    }

    /// Per-token maximum over all values; a token's running maximum starts at `0`.
    fn max(values: impl Iterator<Item = Self>) -> Self {
        values.fold(Self::default(), |mut acc, value| {
            for (token_type, val) in value.0 {
                let current = acc.0.get(&token_type).copied().unwrap_or(0);
                acc.0.insert(token_type, Ord::max(current, val));
            }
            acc
        })
    }

    /// Clamps every negative per-token amount to `0`.
    fn rectify(value: &Self) -> Self {
        PreCost(
            value.0.iter().map(|(token_type, val)| (*token_type, Ord::max(*val, 0))).collect(),
        )
    }
}
pub fn compute_costs<
CostType: CostTypeTrait,
SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
program: &Program,
get_cost_fn: &dyn Fn(&ConcreteLibfuncId) -> Vec<BranchCost>,
specific_cost_context: &SpecificCostContext,
enforced_wallet_values: &OrderedHashMap<StatementIdx, CostType>,
) -> Result<GasInfo, CostError> {
let mut context = CostContext {
program,
get_cost_fn,
enforced_wallet_values,
costs: Default::default(),
target_values: Default::default(),
};
context.prepare_wallet(specific_cost_context)?;
if SpecificCostContext::should_handle_excess() {
context.target_values = context.compute_target_values(specific_cost_context)?;
context.costs = Default::default();
context.prepare_wallet(specific_cost_context)?;
for (idx, value) in enforced_wallet_values.iter() {
if context.wallet_at_ex(idx, false).value != *value {
return Err(CostError::EnforceWalletValueFailed(*idx));
}
}
} else {
assert!(
enforced_wallet_values.is_empty(),
"enforced_wallet_values cannot be used when should_handle_excess is false."
);
}
let mut variable_values = VariableValues::default();
for i in 0..program.statements.len() {
analyze_gas_statements(
&context,
specific_cost_context,
&StatementIdx(i),
&mut variable_values,
);
}
let function_costs = program
.funcs
.iter()
.map(|func| {
let res = SpecificCostContext::to_cost_map(context.wallet_at(&func.entry_point).value);
(func.id.clone(), res)
})
.collect();
Ok(GasInfo { variable_values, function_costs })
}
fn get_branch_requirements_dependencies(
idx: &StatementIdx,
invocation: &Invocation,
libfunc_cost: &[BranchCost],
) -> OrderedHashSet<StatementIdx> {
let mut res: OrderedHashSet<StatementIdx> = Default::default();
for (branch_info, branch_cost) in zip_eq(&invocation.branches, libfunc_cost) {
match branch_cost {
BranchCost::FunctionCall { const_cost: _, function } => {
res.insert(function.entry_point);
}
BranchCost::WithdrawGas { const_cost: _, success: true, with_builtin_costs: _ } => {
continue;
}
_ => {}
}
res.insert(idx.next(&branch_info.target));
}
res
}
/// Returns, for each branch of `invocation` (in order), the wallet requirement reported by
/// `specific_context.get_branch_requirement`.
///
/// `invocation.branches` and `libfunc_cost` are zipped with `zip_eq`, so they are expected to have
/// the same length.
fn get_branch_requirements<
    CostType: CostTypeTrait,
    SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
    specific_context: &SpecificCostContext,
    wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<CostType>,
    idx: &StatementIdx,
    invocation: &Invocation,
    libfunc_cost: &[BranchCost],
) -> Vec<WalletInfo<CostType>> {
    let mut requirements = Vec::with_capacity(libfunc_cost.len());
    for (branch_info, branch_cost) in zip_eq(&invocation.branches, libfunc_cost) {
        let requirement =
            specific_context.get_branch_requirement(wallet_at_fn, idx, branch_info, branch_cost);
        requirements.push(requirement);
    }
    requirements
}
/// Records in `variable_values` the amounts associated with the statement at `idx`.
///
/// Non-invocation statements are ignored. For an invocation, each branch is handled by its cost
/// kind:
/// * successful `WithdrawGas` - the positive part of the withdrawal is stored at `idx`, and the
///   positive part of its negation at the branch target;
/// * `RedepositGas` - the difference between the wallet value and the branch requirement is
///   stored at `idx`;
/// * any other branch of a multi-branch invocation - the same difference is stored at the branch
///   target.
///
/// # Panics
/// Panics if a key is written to `variable_values` twice.
fn analyze_gas_statements<
    CostType: CostTypeTrait,
    SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
    context: &CostContext<'_, CostType>,
    specific_context: &SpecificCostContext,
    idx: &StatementIdx,
    variable_values: &mut VariableValues,
) {
    let Statement::Invocation(invocation) = &context.program.get_statement(idx).unwrap() else {
        return;
    };
    let libfunc_cost: Vec<BranchCost> = context.get_cost(&invocation.libfunc_id);
    let branch_requirements: Vec<WalletInfo<CostType>> = get_branch_requirements(
        specific_context,
        &|statement_idx| context.wallet_at(statement_idx),
        idx,
        invocation,
        &libfunc_cost,
    );
    let wallet_value = context.wallet_at(idx).value;
    let has_multiple_branches = invocation.branches.len() > 1;

    for (branch_info, branch_cost, branch_requirement) in
        zip_eq3(&invocation.branches, &libfunc_cost, &branch_requirements)
    {
        let target_idx = idx.next(&branch_info.target);
        // Looked up for every branch, regardless of whether the branch kind uses it.
        let future_wallet_value = context.wallet_at(&target_idx).value;

        match branch_cost {
            BranchCost::WithdrawGas { success: true, .. } => {
                let withdrawal = specific_context.get_gas_withdrawal(
                    idx,
                    branch_cost,
                    &wallet_value,
                    future_wallet_value,
                );
                for (token_type, amount) in SpecificCostContext::to_full_cost_map(withdrawal) {
                    let previous =
                        variable_values.insert((*idx, token_type), std::cmp::max(amount, 0));
                    assert_eq!(previous, None);

                    let previous = variable_values
                        .insert((target_idx, token_type), std::cmp::max(-amount, 0));
                    assert_eq!(previous, None);
                }
            }
            BranchCost::RedepositGas => {
                let cost = wallet_value.clone() - branch_requirement.value.clone();
                for (token_type, amount) in SpecificCostContext::to_full_cost_map(cost) {
                    let previous = variable_values.insert((*idx, token_type), amount);
                    assert_eq!(previous, None);
                }
            }
            _ if has_multiple_branches => {
                let cost = wallet_value.clone() - branch_requirement.value.clone();
                for (token_type, amount) in SpecificCostContext::to_full_cost_map(cost) {
                    let previous = variable_values.insert((target_idx, token_type), amount);
                    assert_eq!(previous, None);
                }
            }
            _ => {}
        }
    }
}
/// The cost-type specific part of the computation, used by `compute_costs`.
///
/// Implemented in this file by `PreCostContext` (for `PreCost`) and `PostcostContext`
/// (for `i32`).
pub trait SpecificCostContextTrait<CostType: CostTypeTrait> {
/// Whether `compute_costs` should compute target values and recompute the wallet values
/// accordingly. When this returns `false`, `compute_costs` asserts that no enforced wallet
/// values were given.
fn should_handle_excess() -> bool;
/// Converts a cost value to a map from token type to amount.
/// Used by `compute_costs` to report the cost of each function.
fn to_cost_map(cost: CostType) -> OrderedHashMap<CostTokenType, i64>;
/// Converts a cost value to a map from token type to amount.
/// Used by `analyze_gas_statements` to record variable values.
// NOTE(review): the visible implementations emit an entry for every token type they handle,
// including zero amounts, whereas `to_cost_map` may omit entries - confirm this is the intended
// contract for new implementations.
fn to_full_cost_map(cost: CostType) -> OrderedHashMap<CostTokenType, i64>;
/// Computes the withdrawal amount for the success branch of a withdraw-gas statement at `idx`,
/// given the wallet value at `idx` and the wallet value at the branch target.
fn get_gas_withdrawal(
&self,
idx: &StatementIdx,
branch_cost: &BranchCost,
wallet_value: &CostType,
future_wallet_value: CostType,
) -> CostType;
/// Returns the wallet value required at `idx` in order to take the given branch.
/// `wallet_at_fn` provides the wallet values of other statements.
fn get_branch_requirement(
&self,
wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<CostType>,
idx: &StatementIdx,
branch_info: &BranchInfo,
branch_cost: &BranchCost,
) -> WalletInfo<CostType>;
}
/// The wallet information of a single statement.
#[derive(Clone, Debug, Default)]
pub struct WalletInfo<CostType: CostTypeTrait> {
/// The wallet value (of the cost type in use) at the statement.
value: CostType,
}
impl<CostType: CostTypeTrait> WalletInfo<CostType> {
    /// Merges the wallet requirements of the branches of a statement into the wallet value of
    /// the statement itself.
    ///
    /// The result is the maximum over `branches`. If the statement has more than one branch, or
    /// its only branch cost is `BranchCost::RedepositGas`, the result is additionally raised to
    /// `target_value` (when given).
    fn merge(
        branch_costs: &[BranchCost],
        branches: Vec<Self>,
        target_value: Option<&CostType>,
    ) -> Self {
        let branches_max =
            CostType::max(branches.iter().map(|wallet_info| wallet_info.value.clone()));
        let may_raise_to_target =
            branches.len() > 1 || matches!(branch_costs[..], [BranchCost::RedepositGas]);

        let value = match target_value {
            Some(target_value) if may_raise_to_target => {
                CostType::max([branches_max, target_value.clone()].into_iter())
            }
            _ => branches_max,
        };
        WalletInfo { value }
    }
}
/// Wraps a raw cost value as a [`WalletInfo`].
impl<CostType: CostTypeTrait> From<CostType> for WalletInfo<CostType> {
    fn from(value: CostType) -> Self {
        Self { value }
    }
}
/// Adds two wallet infos by adding their values.
impl<CostType: CostTypeTrait> std::ops::Add for WalletInfo<CostType> {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        let value = self.value + other.value;
        Self { value }
    }
}
/// The state shared by the steps of the cost computation.
struct CostContext<'a, CostType: CostTypeTrait> {
/// The Sierra program being analyzed.
program: &'a Program,
/// Returns the branch costs of the given libfunc.
get_cost_fn: &'a dyn Fn(&ConcreteLibfuncId) -> Vec<BranchCost>,
/// Wallet values that override the computed ones in `wallet_at`.
enforced_wallet_values: &'a OrderedHashMap<StatementIdx, CostType>,
/// The wallet value computed for each statement; filled by `prepare_wallet`.
costs: UnorderedHashMap<StatementIdx, WalletInfo<CostType>>,
/// Per-statement target values passed to `WalletInfo::merge`; empty until
/// `compute_costs` assigns the result of `compute_target_values`.
target_values: UnorderedHashMap<StatementIdx, CostType>,
}
impl<'a, CostType: CostTypeTrait> CostContext<'a, CostType> {
/// Returns the branch costs of `libfunc_id`, as reported by the `get_cost_fn` callback.
fn get_cost(&self, libfunc_id: &ConcreteLibfuncId) -> Vec<BranchCost> {
    let cost_fn = self.get_cost_fn;
    cost_fn(libfunc_id)
}
/// Returns the wallet value at `idx`, preferring an enforced wallet value when one exists.
///
/// # Panics
/// Panics if no value is enforced at `idx` and the wallet value was not computed yet.
fn wallet_at(&self, idx: &StatementIdx) -> WalletInfo<CostType> {
    let with_enforced_values = true;
    self.wallet_at_ex(idx, with_enforced_values)
}
/// Returns the wallet value at `idx`.
///
/// When `with_enforced_values` is `true` and a wallet value is enforced at `idx`, the enforced
/// value is returned. Otherwise the previously computed value is returned.
///
/// # Panics
/// Panics if the computed value is needed but was not computed yet.
fn wallet_at_ex(&self, idx: &StatementIdx, with_enforced_values: bool) -> WalletInfo<CostType> {
    let enforced =
        if with_enforced_values { self.enforced_wallet_values.get(idx) } else { None };
    if let Some(enforced_wallet_value) = enforced {
        return WalletInfo::from(enforced_wallet_value.clone());
    }
    match self.costs.get(idx) {
        Some(wallet_info) => wallet_info.clone(),
        None => panic!("Wallet value for statement {idx} was not yet computed."),
    }
}
/// Computes the wallet value of every statement and stores it in `self.costs`.
///
/// The statements are processed in a topological order (with cycle detection enabled) of the
/// dependencies given by `get_branch_requirements_dependencies`, so that the values a statement
/// relies on are already available when it is processed.
///
/// # Errors
/// Propagates the error of `compute_topological_order` (e.g. when a cycle is found).
fn prepare_wallet<SpecificCostContext: SpecificCostContextTrait<CostType>>(
    &mut self,
    specific_cost_context: &SpecificCostContext,
) -> Result<(), CostError> {
    let topological_order =
        compute_topological_order(self.program.statements.len(), true, |current_idx| {
            match &self.program.get_statement(current_idx).unwrap() {
                // A return statement has no dependencies.
                Statement::Return(_) => {
                    vec![]
                }
                Statement::Invocation(invocation) => {
                    let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
                    get_branch_requirements_dependencies(current_idx, invocation, &libfunc_cost)
                        .into_iter()
                        .collect()
                }
            }
        })?;
    for current_idx in topological_order {
        // Compute first, insert afterwards: the computation reads `self.costs` through
        // `wallet_at`.
        let res = self.no_cache_compute_wallet_at(&current_idx, specific_cost_context);
        self.costs.insert(current_idx, res);
    }
    Ok(())
}
/// Computes the wallet value at `idx` from the wallet values of the statements it depends on,
/// without storing the result.
///
/// A return statement gets the default wallet value. An invocation gets the merge of its branch
/// requirements, taking the target value of `idx` (if any) into account.
fn no_cache_compute_wallet_at<SpecificCostContext: SpecificCostContextTrait<CostType>>(
    &mut self,
    idx: &StatementIdx,
    specific_cost_context: &SpecificCostContext,
) -> WalletInfo<CostType> {
    let Statement::Invocation(invocation) = &self.program.get_statement(idx).unwrap() else {
        return Default::default();
    };
    let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
    let branch_requirements: Vec<WalletInfo<CostType>> = get_branch_requirements(
        specific_cost_context,
        &|statement_idx| self.wallet_at(statement_idx),
        idx,
        invocation,
        &libfunc_cost,
    );
    let target_value = self.target_values.get(idx);
    WalletInfo::merge(&libfunc_cost, branch_requirements, target_value)
}
fn compute_target_values<SpecificCostContext: SpecificCostContextTrait<CostType>>(
&self,
specific_cost_context: &SpecificCostContext,
) -> Result<UnorderedHashMap<StatementIdx, CostType>, CostError> {
let topological_order =
compute_topological_order(self.program.statements.len(), false, |current_idx| {
match self.program.get_statement(current_idx).unwrap() {
Statement::Return(_) => {
vec![]
}
Statement::Invocation(invocation) => invocation
.branches
.iter()
.map(|branch_info| current_idx.next(&branch_info.target))
.collect(),
}
})?;
let mut excess = UnorderedHashMap::<StatementIdx, CostType>::default();
let mut finalized_excess_statements = UnorderedHashSet::<StatementIdx>::default();
for idx in topological_order.iter().rev() {
self.handle_excess_at(
idx,
specific_cost_context,
&mut excess,
&mut finalized_excess_statements,
);
}
Ok((0..self.program.statements.len())
.map(|i| {
let idx = StatementIdx(i);
let original_wallet_value = self.wallet_at_ex(&idx, false).value;
(idx, original_wallet_value + excess.get(&idx).cloned().unwrap_or_default())
})
.collect())
}
/// Finalizes the excess of the statement at `idx` and passes it on to its branch targets.
///
/// * `excess` - the excess collected so far for each statement; updated for the branch targets
///   of `idx` (and for `idx` itself when a wallet value is enforced there).
/// * `finalized_excess_statements` - the statements that were already handled by this function;
///   `idx` is added to it.
fn handle_excess_at<SpecificCostContext: SpecificCostContextTrait<CostType>>(
&self,
idx: &StatementIdx,
specific_cost_context: &SpecificCostContext,
excess: &mut UnorderedHashMap<StatementIdx, CostType>,
finalized_excess_statements: &mut UnorderedHashSet<StatementIdx>,
) {
// The computed wallet value, ignoring any enforced value.
let wallet_value = self.wallet_at_ex(idx, false).value;
if let Some(enforced_wallet_value) = self.enforced_wallet_values.get(idx) {
// An enforced value overrides the collected excess: the excess becomes the (non-negative)
// gap between the enforced value and the computed one.
excess.insert(
*idx,
CostType::rectify(&(enforced_wallet_value.clone() - wallet_value.clone())),
);
}
finalized_excess_statements.insert(*idx);
let current_excess = excess.get(idx).cloned().unwrap_or_default();
let invocation = match &self.program.get_statement(idx).unwrap() {
Statement::Invocation(invocation) => invocation,
Statement::Return(_) => {
// A return statement has no branches to pass the excess to.
return;
}
};
let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
let branch_requirements = get_branch_requirements(
specific_cost_context,
&|statement_idx| self.wallet_at(statement_idx),
idx,
invocation,
&libfunc_cost,
);
// Pass the excess to each of the branches.
for (branch_info, branch_cost, branch_requirement) in
zip_eq3(&invocation.branches, &libfunc_cost, branch_requirements)
{
let branch_statement = idx.next(&branch_info.target);
if finalized_excess_statements.contains(&branch_statement) {
// The target was already finalized, so its excess is not updated.
// NOTE(review): this exits the whole function, so the remaining branches of this
// statement are skipped as well - confirm that `continue` is not what was intended.
return;
}
let future_wallet_value = self.wallet_at(&branch_statement).value;
let mut actual_excess = current_excess.clone();
if invocation.branches.len() > 1 {
if let BranchCost::WithdrawGas { success: true, .. } = branch_cost {
// The planned withdrawal is deducted from the excess (never going below zero).
let planned_withdrawal = specific_cost_context.get_gas_withdrawal(
idx,
branch_cost,
&wallet_value,
future_wallet_value,
);
actual_excess = CostType::rectify(&(actual_excess - planned_withdrawal));
} else {
// Whatever the wallet holds beyond the requirement of this branch is added to the
// excess passed to the branch.
let additional_excess = wallet_value.clone() - branch_requirement.value;
actual_excess = actual_excess + CostType::rectify(&additional_excess);
}
} else if let BranchCost::RedepositGas = branch_cost {
// No excess is passed on after a redeposit.
actual_excess = Default::default();
}
// Keep the minimum between the excess already recorded for the target and the one
// suggested by this branch.
match excess.entry(branch_statement) {
hash_map::Entry::Occupied(mut entry) => {
let current_value = entry.get();
entry.insert(CostType::min2(current_value, &actual_excess));
}
hash_map::Entry::Vacant(entry) => {
entry.insert(actual_excess);
}
}
}
}
}
/// Returns a topological ordering of the statements `0..n_statements`, using every statement as
/// a root and `dependencies_callback` to obtain the dependencies of a statement.
///
/// `detect_cycles` is forwarded to `get_topological_ordering`.
///
/// # Errors
/// `CostError::StatementOutOfBounds` and `CostError::UnexpectedCycle` are the errors supplied to
/// `get_topological_ordering` for out-of-bounds statements and for cycles, respectively.
fn compute_topological_order(
    n_statements: usize,
    detect_cycles: bool,
    dependencies_callback: impl Fn(&StatementIdx) -> Vec<StatementIdx>,
) -> Result<Vec<StatementIdx>, CostError> {
    let node_count = n_statements;
    let roots = (0..node_count).map(StatementIdx);
    get_topological_ordering(
        detect_cycles,
        roots,
        node_count,
        |idx| Ok(dependencies_callback(&idx)),
        CostError::StatementOutOfBounds,
        |_| CostError::UnexpectedCycle,
    )
}
/// The [`SpecificCostContextTrait`] implementation for the `PreCost` cost type. Holds no state.
pub struct PreCostContext {}
impl SpecificCostContextTrait<PreCost> for PreCostContext {
fn should_handle_excess() -> bool {
false
}
fn to_cost_map(cost: PreCost) -> OrderedHashMap<CostTokenType, i64> {
let res = cost.0;
res.into_iter().map(|(token_type, val)| (token_type, val as i64)).collect()
}
fn to_full_cost_map(cost: PreCost) -> OrderedHashMap<CostTokenType, i64> {
CostTokenType::iter_precost()
.map(|token_type| (*token_type, (*cost.0.get(token_type).unwrap_or(&0)).into()))
.collect()
}
fn get_gas_withdrawal(
&self,
_idx: &StatementIdx,
_branch_cost: &BranchCost,
wallet_value: &PreCost,
future_wallet_value: PreCost,
) -> PreCost {
future_wallet_value - wallet_value.clone()
}
fn get_branch_requirement(
&self,
wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<PreCost>,
idx: &StatementIdx,
branch_info: &BranchInfo,
branch_cost: &BranchCost,
) -> WalletInfo<PreCost> {
let branch_cost = match branch_cost {
BranchCost::Regular { const_cost: _, pre_cost } => pre_cost.clone(),
BranchCost::BranchAlign => Default::default(),
BranchCost::FunctionCall { const_cost: _, function } => {
wallet_at_fn(&function.entry_point).value
}
BranchCost::WithdrawGas { const_cost: _, success, with_builtin_costs: _ } => {
if *success {
return Default::default();
} else {
Default::default()
}
}
BranchCost::RedepositGas => Default::default(),
};
let future_wallet_value = wallet_at_fn(&idx.next(&branch_info.target));
WalletInfo::from(branch_cost) + future_wallet_value
}
}
/// The [`SpecificCostContextTrait`] implementation for the `i32` cost type.
pub struct PostcostContext<'a> {
/// Returns the ap-change of the given statement; queried for `BranchCost::BranchAlign` branches.
pub get_ap_change_fn: &'a dyn Fn(&StatementIdx) -> usize,
/// The gas info of the pre-cost computation; its variable values are used when computing the
/// cost of withdraw-gas with builtin costs.
pub precost_gas_info: &'a GasInfo,
}
impl<'a> SpecificCostContextTrait<i32> for PostcostContext<'a> {
    /// Excess handling is enabled for the post-cost computation.
    fn should_handle_excess() -> bool {
        true
    }

    /// Returns an empty map for a zero cost, and a single `CostTokenType::Const` entry otherwise.
    fn to_cost_map(cost: i32) -> OrderedHashMap<CostTokenType, i64> {
        if cost == 0 { Default::default() } else { Self::to_full_cost_map(cost) }
    }

    /// Returns a map with a single `CostTokenType::Const` entry holding `cost`.
    fn to_full_cost_map(cost: i32) -> OrderedHashMap<CostTokenType, i64> {
        [(CostTokenType::Const, cost.into())].into_iter().collect()
    }

    /// Returns the amount to withdraw on the success branch of withdraw-gas: the wallet value
    /// at the branch target plus the cost of the withdraw-gas itself, minus the current wallet
    /// value.
    ///
    /// # Panics
    /// Panics if `branch_cost` is not a successful `BranchCost::WithdrawGas`.
    fn get_gas_withdrawal(
        &self,
        idx: &StatementIdx,
        branch_cost: &BranchCost,
        wallet_value: &i32,
        future_wallet_value: i32,
    ) -> i32 {
        let BranchCost::WithdrawGas { const_cost, success: true, with_builtin_costs } = branch_cost
        else {
            panic!("Unexpected BranchCost: {:?}.", branch_cost);
        };
        let withdraw_gas_cost =
            self.compute_withdraw_gas_cost(idx, const_cost, *with_builtin_costs);
        future_wallet_value + withdraw_gas_cost - *wallet_value
    }

    /// Returns the wallet value required at `idx` for taking the given branch: the cost of the
    /// branch itself plus the wallet value at the branch target.
    ///
    /// The success branch of withdraw-gas only requires the cost of the withdraw-gas itself.
    ///
    /// # Panics
    /// Panics if the ap-change of a branch-align statement does not fit in an `i32`.
    fn get_branch_requirement(
        &self,
        wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<i32>,
        idx: &StatementIdx,
        branch_info: &BranchInfo,
        branch_cost: &BranchCost,
    ) -> WalletInfo<i32> {
        let branch_cost_val = match branch_cost {
            BranchCost::Regular { const_cost, pre_cost: _ } => const_cost.cost(),
            BranchCost::BranchAlign => {
                let ap_change = (self.get_ap_change_fn)(idx);
                if ap_change == 0 {
                    0
                } else {
                    // Checked conversion: a plain `as i32` cast would silently wrap around.
                    let holes = ap_change.into_or_panic::<i32>();
                    ConstCost { steps: 1, holes, range_checks: 0 }.cost()
                }
            }
            BranchCost::FunctionCall { const_cost, function } => {
                wallet_at_fn(&function.entry_point).value + const_cost.cost()
            }
            BranchCost::WithdrawGas { const_cost, success, with_builtin_costs } => {
                let cost = self.compute_withdraw_gas_cost(idx, const_cost, *with_builtin_costs);
                if *success {
                    // The wallet value at the target is not taken into account on success.
                    return WalletInfo::from(cost);
                }
                cost
            }
            BranchCost::RedepositGas => 0,
        };
        let future_wallet_value = wallet_at_fn(&idx.next(&branch_info.target));
        WalletInfo { value: branch_cost_val } + future_wallet_value
    }
}
impl<'a> PostcostContext<'a> {
    /// Computes the cost of the withdraw-gas libfunc at `idx`.
    ///
    /// The result is the cost of `const_cost`. When `with_builtin_costs` is `true`, the cost of
    /// the steps reported by `BuiltinCostWithdrawGasLibfunc::cost_computation_steps` is added,
    /// based on the pre-cost variable values of the statement.
    fn compute_withdraw_gas_cost(
        &self,
        idx: &StatementIdx,
        const_cost: &ConstCost,
        with_builtin_costs: bool,
    ) -> i32 {
        let base_cost = const_cost.cost();
        if !with_builtin_costs {
            return base_cost;
        }
        let steps = BuiltinCostWithdrawGasLibfunc::cost_computation_steps(|token_type| {
            self.precost_gas_info.variable_values[(*idx, token_type)].into_or_panic()
        })
        .into_or_panic::<i32>();
        base_cost + ConstCost { steps, ..Default::default() }.cost()
    }
}