use std::ops::{Add, Sub};
use cairo_lang_sierra::extensions::gas::{BuiltinCostWithdrawGasLibfunc, CostTokenType};
use cairo_lang_sierra::ids::ConcreteLibfuncId;
use cairo_lang_sierra::program::{BranchInfo, Invocation, Program, Statement, StatementIdx};
use cairo_lang_utils::casts::IntoOrPanic;
use cairo_lang_utils::iterators::zip_eq3;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use itertools::zip_eq;
use crate::gas_info::GasInfo;
use crate::generate_equations::{calculate_reverse_topological_ordering, TopologicalOrderStatus};
use crate::objects::{BranchCost, ConstCost, PreCost};
use crate::CostError;
/// Values assigned to the cost variables, keyed by the statement and the cost token type they
/// belong to.
type VariableValues = OrderedHashMap<(StatementIdx, CostTokenType), i64>;
/// Operations the cost computation needs from a cost type (`i32` for post-costs and `PreCost`
/// for pre-costs in this file).
pub trait CostTypeTrait
where
    Self: std::fmt::Debug + Default + Clone + Eq + Add<Output = Self> + Sub<Output = Self>,
{
    /// Combines `values` into a single cost. The implementations in this file take the
    /// (per-token) maximum, and return the default value for an empty iterator.
    fn max(values: impl Iterator<Item = Self>) -> Self;
}
impl CostTypeTrait for i32 {
    /// Returns the largest of `values`, or `0` when there are none.
    fn max(values: impl Iterator<Item = Self>) -> Self {
        values.reduce(std::cmp::max).unwrap_or_default()
    }
}
impl CostTypeTrait for PreCost {
    /// Returns the per-token maximum over `values`.
    ///
    /// A token type missing from the result so far is treated as 0, so every entry of the result
    /// is at least 0. An empty iterator yields `PreCost::default()`.
    fn max(values: impl Iterator<Item = Self>) -> Self {
        let mut merged = Self::default();
        for (token_type, val) in values.flat_map(|value| value.0) {
            let so_far = merged.0.get(&token_type).copied().unwrap_or(0);
            merged.0.insert(token_type, std::cmp::max(so_far, val));
        }
        merged
    }
}
pub fn compute_costs<
CostType: CostTypeTrait,
SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
program: &Program,
get_cost_fn: &dyn Fn(&ConcreteLibfuncId) -> Vec<BranchCost>,
specific_cost_context: &SpecificCostContext,
) -> Result<GasInfo, CostError> {
let mut context = CostContext { program, costs: UnorderedHashMap::default(), get_cost_fn };
context.prepare_wallet(specific_cost_context)?;
let mut variable_values = VariableValues::default();
for i in 0..program.statements.len() {
analyze_gas_statements(
&context,
specific_cost_context,
&StatementIdx(i),
&mut variable_values,
);
}
let function_costs = program
.funcs
.iter()
.map(|func| {
let res = SpecificCostContext::to_cost_map(context.wallet_at(&func.entry_point).value);
(func.id.clone(), res)
})
.collect();
Ok(GasInfo { variable_values, function_costs })
}
fn get_branch_requirements_dependencies(
idx: &StatementIdx,
invocation: &Invocation,
libfunc_cost: &[BranchCost],
) -> OrderedHashSet<StatementIdx> {
let mut res: OrderedHashSet<StatementIdx> = Default::default();
for (branch_info, branch_cost) in zip_eq(&invocation.branches, libfunc_cost) {
match branch_cost {
BranchCost::FunctionCall { const_cost: _, function } => {
res.insert(function.entry_point);
}
BranchCost::WithdrawGas { const_cost: _, success: true, with_builtin_costs: _ } => {
continue;
}
_ => {}
}
res.insert(idx.next(&branch_info.target));
}
res
}
/// Returns, in branch order, the wallet required by each branch of the invocation at `idx`.
///
/// Each entry is computed by `specific_context`, which reads already-computed wallets through
/// `wallet_at_fn`.
///
/// # Panics
/// Panics if `libfunc_cost` does not have exactly one entry per branch.
fn get_branch_requirements<
    CostType: CostTypeTrait,
    SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
    specific_context: &SpecificCostContext,
    wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<CostType>,
    idx: &StatementIdx,
    invocation: &Invocation,
    libfunc_cost: &[BranchCost],
) -> Vec<WalletInfo<CostType>> {
    let mut requirements = Vec::with_capacity(invocation.branches.len());
    for (branch_info, branch_cost) in zip_eq(&invocation.branches, libfunc_cost) {
        let requirement =
            specific_context.get_branch_requirement(wallet_at_fn, idx, branch_info, branch_cost);
        requirements.push(requirement);
    }
    requirements
}
/// Records in `variable_values` the values implied by the statement at `idx`.
///
/// Non-invocation statements record nothing. For each branch of an invocation:
/// * A successful `withdraw_gas` branch computes the withdrawal amount (see
///   `SpecificCostContextTrait::get_gas_withdrawal`). For every token type it records
///   `max(amount, 0)` keyed by `idx` and `max(-amount, 0)` keyed by the branch target.
/// * Any other branch of a multi-branch invocation records, keyed by the branch target, the
///   difference between the wallet at `idx` and that branch's requirement.
///
/// Single-branch invocations that are not a successful `withdraw_gas` record nothing.
///
/// # Panics
/// Panics if a `(statement, token type)` key is recorded twice, or if a wallet needed here was
/// not computed beforehand (see `CostContext::wallet_at`).
fn analyze_gas_statements<
    CostType: CostTypeTrait,
    SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
    context: &CostContext<'_, CostType>,
    specific_context: &SpecificCostContext,
    idx: &StatementIdx,
    variable_values: &mut VariableValues,
) {
    // Only invocations produce variable values.
    let Statement::Invocation(invocation) = &context.program.get_statement(idx).unwrap() else {
        return;
    };
    let libfunc_cost: Vec<BranchCost> = context.get_cost(&invocation.libfunc_id);
    // The wallet required by each branch, in branch order.
    let branch_requirements: Vec<WalletInfo<CostType>> = get_branch_requirements(
        specific_context,
        &|statement_idx| context.wallet_at(statement_idx),
        idx,
        invocation,
        &libfunc_cost,
    );
    let wallet_value = context.wallet_at(idx).value;
    for (branch_info, branch_cost, branch_requirement) in
        zip_eq3(&invocation.branches, &libfunc_cost, &branch_requirements)
    {
        // The wallet required at this branch's target statement.
        let future_wallet_value = context.wallet_at(&idx.next(&branch_info.target)).value;
        if let BranchCost::WithdrawGas { success: true, .. } = branch_cost {
            let withdrawal = specific_context.get_gas_withdrawal(
                idx,
                branch_cost,
                &wallet_value,
                future_wallet_value,
            );
            for (token_type, amount) in SpecificCostContext::to_full_cost_map(withdrawal) {
                // Non-negative part of the withdrawal, keyed by the `withdraw_gas` statement.
                assert_eq!(
                    variable_values.insert((*idx, token_type), std::cmp::max(amount, 0)),
                    None
                );
                // Negated negative part of the withdrawal (0 if none), keyed by the branch target.
                assert_eq!(
                    variable_values.insert(
                        (idx.next(&branch_info.target), token_type),
                        std::cmp::max(-amount, 0),
                    ),
                    None
                );
            }
        } else if invocation.branches.len() > 1 {
            // Difference between the wallet at `idx` and what this particular branch requires.
            let cost = wallet_value.clone() - branch_requirement.value.clone();
            for (token_type, amount) in SpecificCostContext::to_full_cost_map(cost) {
                assert_eq!(
                    variable_values.insert((idx.next(&branch_info.target), token_type), amount),
                    None
                );
            }
        }
    }
}
/// The parts of the cost computation that differ between the pre-cost pass (`PreCost`) and the
/// post-cost pass (`i32`).
pub trait SpecificCostContextTrait<CostType>
where
    CostType: CostTypeTrait,
{
    /// Converts `cost` into a map keyed by token type; used for the per-function costs.
    /// Unlike `to_full_cost_map`, implementations may leave out some token types.
    fn to_cost_map(cost: CostType) -> OrderedHashMap<CostTokenType, i64>;

    /// Converts `cost` into a map keyed by token type that lists every token type the
    /// implementation tracks; used when recording variable values.
    fn to_full_cost_map(cost: CostType) -> OrderedHashMap<CostTokenType, i64>;

    /// Returns the amount withdrawn by the successful branch of the `withdraw_gas` at `idx`,
    /// given the wallet at `idx` and the wallet at that branch's target.
    fn get_gas_withdrawal(
        &self,
        idx: &StatementIdx,
        branch_cost: &BranchCost,
        wallet_value: &CostType,
        future_wallet_value: CostType,
    ) -> CostType;

    /// Returns the wallet required to take `branch_info` from the invocation at `idx`, reading
    /// already-computed wallets through `wallet_at_fn`.
    fn get_branch_requirement(
        &self,
        wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<CostType>,
        idx: &StatementIdx,
        branch_info: &BranchInfo,
        branch_cost: &BranchCost,
    ) -> WalletInfo<CostType>;
}
/// The amount of cost a wallet holds at some statement.
#[derive(Clone, Debug, Default)]
pub struct WalletInfo<CostType>
where
    CostType: CostTypeTrait,
{
    /// The amount itself.
    value: CostType,
}
impl<CostType: CostTypeTrait> WalletInfo<CostType> {
    /// Combines several branch wallets into one by applying `CostTypeTrait::max` to their
    /// values.
    fn merge(branches: Vec<Self>) -> Self {
        let values = branches.into_iter().map(|branch| branch.value);
        Self { value: CostType::max(values) }
    }
}
impl<CostType: CostTypeTrait> From<CostType> for WalletInfo<CostType> {
    /// Wraps a raw cost as a wallet holding exactly that amount.
    fn from(value: CostType) -> Self {
        Self { value }
    }
}
impl<CostType: CostTypeTrait> std::ops::Add for WalletInfo<CostType> {
    type Output = Self;

    /// Returns a wallet holding the sum of both wallets' values.
    fn add(self, other: Self) -> Self {
        let Self { value: lhs } = self;
        let Self { value: rhs } = other;
        Self { value: lhs + rhs }
    }
}
/// State of the wallet computation: the program, the libfunc cost oracle, and the cache of
/// wallets computed so far.
struct CostContext<'a, CostType>
where
    CostType: CostTypeTrait,
{
    /// The program whose costs are being computed.
    program: &'a Program,
    /// Returns the cost of every branch of a given libfunc.
    get_cost_fn: &'a dyn Fn(&ConcreteLibfuncId) -> Vec<BranchCost>,
    /// The wallet computed for each statement; filled by `prepare_wallet`.
    costs: UnorderedHashMap<StatementIdx, WalletInfo<CostType>>,
}
impl<'a, CostType: CostTypeTrait> CostContext<'a, CostType> {
    /// Returns the cost of each branch of the given libfunc, as reported by `get_cost_fn`.
    fn get_cost(&self, libfunc_id: &ConcreteLibfuncId) -> Vec<BranchCost> {
        (self.get_cost_fn)(libfunc_id)
    }

    /// Returns a clone of the wallet cached for the statement at `idx`.
    ///
    /// # Panics
    /// Panics if no wallet has been cached for `idx` yet (see `prepare_wallet`).
    fn wallet_at(&self, idx: &StatementIdx) -> WalletInfo<CostType> {
        self.costs
            .get(idx)
            .unwrap_or_else(|| panic!("Wallet value for statement {idx} was not yet computed."))
            .clone()
    }

    /// Computes and caches the wallet of every statement in the program.
    ///
    /// The processing order comes from `compute_topological_order`. Dependencies come from
    /// `get_branch_requirements_dependencies`: branch targets and called functions' entry points.
    /// The ordering is meant to compute each statement's dependencies before the statement
    /// itself, so `no_cache_compute_wallet_at` only reads wallets that are already cached.
    ///
    /// # Errors
    /// Returns any error reported while computing the order.
    fn prepare_wallet<SpecificCostContext: SpecificCostContextTrait<CostType>>(
        &mut self,
        specific_cost_context: &SpecificCostContext,
    ) -> Result<(), CostError> {
        let topological_order =
            compute_topological_order(self.program.statements.len(), &|current_idx| {
                match &self.program.get_statement(current_idx).unwrap() {
                    // `return` statements depend on nothing.
                    Statement::Return(_) => {
                        vec![]
                    }
                    Statement::Invocation(invocation) => {
                        let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
                        get_branch_requirements_dependencies(current_idx, invocation, &libfunc_cost)
                            .into_iter()
                            .collect()
                    }
                }
            })?;
        for current_idx in topological_order {
            let wallet = self.no_cache_compute_wallet_at(&current_idx, specific_cost_context);
            // `wallet` is not used again, so it is moved into the cache rather than cloned.
            self.costs.insert(current_idx, wallet);
        }
        Ok(())
    }

    /// Computes the wallet required at `idx` without looking up `idx` itself in the cache.
    ///
    /// A `return` requires the default wallet. An invocation requires the merge
    /// (`WalletInfo::merge`) of its branches' requirements. Wallets of other statements are
    /// read from the cache, so they must already have been computed.
    fn no_cache_compute_wallet_at<SpecificCostContext: SpecificCostContextTrait<CostType>>(
        &self,
        idx: &StatementIdx,
        specific_cost_context: &SpecificCostContext,
    ) -> WalletInfo<CostType> {
        match &self.program.get_statement(idx).unwrap() {
            Statement::Return(_) => Default::default(),
            Statement::Invocation(invocation) => {
                let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
                let branch_requirements: Vec<WalletInfo<CostType>> = get_branch_requirements(
                    specific_cost_context,
                    &|statement_idx| self.wallet_at(statement_idx),
                    idx,
                    invocation,
                    &libfunc_cost,
                );
                WalletInfo::merge(branch_requirements)
            }
        }
    }
}
fn compute_topological_order(
n_statements: usize,
dependencies_callback: &dyn Fn(&StatementIdx) -> Vec<StatementIdx>,
) -> Result<Vec<StatementIdx>, CostError> {
let mut topological_order: Vec<StatementIdx> = Default::default();
let mut status = vec![TopologicalOrderStatus::NotStarted; n_statements];
for idx in 0..n_statements {
calculate_reverse_topological_ordering(
&mut topological_order,
&mut status,
&StatementIdx(idx),
true,
dependencies_callback,
)?;
}
Ok(topological_order)
}
/// The specific cost context for the pre-cost computation, whose costs are `PreCost` values.
pub struct PreCostContext {}
impl SpecificCostContextTrait<PreCost> for PreCostContext {
fn to_cost_map(cost: PreCost) -> OrderedHashMap<CostTokenType, i64> {
let res = cost.0;
res.into_iter().map(|(token_type, val)| (token_type, val as i64)).collect()
}
fn to_full_cost_map(cost: PreCost) -> OrderedHashMap<CostTokenType, i64> {
CostTokenType::iter_precost()
.map(|token_type| (*token_type, (*cost.0.get(token_type).unwrap_or(&0)).into()))
.collect()
}
fn get_gas_withdrawal(
&self,
_idx: &StatementIdx,
_branch_cost: &BranchCost,
wallet_value: &PreCost,
future_wallet_value: PreCost,
) -> PreCost {
future_wallet_value - wallet_value.clone()
}
fn get_branch_requirement(
&self,
wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<PreCost>,
idx: &StatementIdx,
branch_info: &BranchInfo,
branch_cost: &BranchCost,
) -> WalletInfo<PreCost> {
let branch_cost = match branch_cost {
BranchCost::Regular { const_cost: _, pre_cost } => pre_cost.clone(),
BranchCost::BranchAlign => Default::default(),
BranchCost::FunctionCall { const_cost: _, function } => {
wallet_at_fn(&function.entry_point).value
}
BranchCost::WithdrawGas { const_cost: _, success, with_builtin_costs: _ } => {
if *success {
return Default::default();
} else {
Default::default()
}
}
BranchCost::RedepositGas => {
Default::default()
}
};
let future_wallet_value = wallet_at_fn(&idx.next(&branch_info.target));
WalletInfo::from(branch_cost) + future_wallet_value
}
}
/// The specific cost context for the post-cost computation, whose costs are `i32` amounts of the
/// `Const` token.
pub struct PostcostContext<'a> {
    /// Returns the ap change at the given statement; used to cost `BranchAlign` branches.
    pub get_ap_change_fn: &'a dyn Fn(&StatementIdx) -> usize,
    /// Gas info produced by the pre-cost computation. Its `variable_values` are read when costing
    /// a `withdraw_gas` that uses builtin costs.
    pub precost_gas_info: &'a GasInfo,
}
impl<'a> SpecificCostContextTrait<i32> for PostcostContext<'a> {
fn to_cost_map(cost: i32) -> OrderedHashMap<CostTokenType, i64> {
if cost == 0 { Default::default() } else { Self::to_full_cost_map(cost) }
}
fn to_full_cost_map(cost: i32) -> OrderedHashMap<CostTokenType, i64> {
[(CostTokenType::Const, cost.into())].into_iter().collect()
}
fn get_gas_withdrawal(
&self,
idx: &StatementIdx,
branch_cost: &BranchCost,
wallet_value: &i32,
future_wallet_value: i32,
) -> i32 {
let BranchCost::WithdrawGas { const_cost, success: true, with_builtin_costs } = branch_cost
else {
panic!("Unexpected BranchCost: {:?}.", branch_cost);
};
let withdraw_gas_cost =
self.compute_withdraw_gas_cost(idx, const_cost, *with_builtin_costs);
future_wallet_value + withdraw_gas_cost - *wallet_value
}
fn get_branch_requirement(
&self,
wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<i32>,
idx: &StatementIdx,
branch_info: &BranchInfo,
branch_cost: &BranchCost,
) -> WalletInfo<i32> {
let branch_cost_val = match branch_cost {
BranchCost::Regular { const_cost, pre_cost: _ } => const_cost.cost(),
BranchCost::BranchAlign => {
let ap_change = (self.get_ap_change_fn)(idx);
if ap_change == 0 {
0
} else {
ConstCost { steps: 1, holes: ap_change as i32, range_checks: 0 }.cost()
}
}
BranchCost::FunctionCall { const_cost, function } => {
wallet_at_fn(&function.entry_point).value + const_cost.cost()
}
BranchCost::WithdrawGas { const_cost, success, with_builtin_costs } => {
let cost = self.compute_withdraw_gas_cost(idx, const_cost, *with_builtin_costs);
if *success {
return WalletInfo::from(cost);
}
cost
}
BranchCost::RedepositGas => 0,
};
let future_wallet_value = wallet_at_fn(&idx.next(&branch_info.target));
WalletInfo { value: branch_cost_val } + future_wallet_value
}
}
impl PostcostContext<'_> {
    /// Returns the constant cost of the `withdraw_gas` at `idx`.
    ///
    /// When `with_builtin_costs` is set, this adds the steps reported by
    /// `BuiltinCostWithdrawGasLibfunc::cost_computation_steps`. That function is fed this
    /// statement's pre-cost variable values.
    ///
    /// # Panics
    /// Panics if a needed pre-cost variable value is missing or a conversion does not fit.
    fn compute_withdraw_gas_cost(
        &self,
        idx: &StatementIdx,
        const_cost: &ConstCost,
        with_builtin_costs: bool,
    ) -> i32 {
        let base_cost = const_cost.cost();
        if !with_builtin_costs {
            return base_cost;
        }
        let steps = BuiltinCostWithdrawGasLibfunc::cost_computation_steps(|token_type| {
            self.precost_gas_info.variable_values[(*idx, token_type)].into_or_panic()
        })
        .into_or_panic::<i32>();
        base_cost + ConstCost { steps, ..Default::default() }.cost()
    }
}