use std::ops::{Add, Sub};
use cairo_lang_sierra::extensions::gas::CostTokenType;
use cairo_lang_sierra::ids::ConcreteLibfuncId;
use cairo_lang_sierra::program::{BranchInfo, Invocation, Program, Statement, StatementIdx};
use cairo_lang_utils::iterators::zip_eq3;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use itertools::zip_eq;
use crate::gas_info::GasInfo;
use crate::objects::{BranchCost, PreCost};
/// Gas variable values produced by [`compute_costs`], keyed by the statement they are recorded for
/// and the cost token they refer to.
type VariableValues = OrderedHashMap<(StatementIdx, CostTokenType), i64>;
/// A cost value handled by the generic cost computation.
///
/// Values can be added, subtracted and compared for equality. Several values can be combined
/// into one with [`CostTypeTrait::max`].
pub trait CostTypeTrait
where
    Self: std::fmt::Debug + Default + Clone + Eq + Add<Output = Self> + Sub<Output = Self>,
{
    /// Combines `values` into a single value. `WalletInfo::merge` uses this to join the
    /// requirements of the branches of a statement.
    fn max(values: impl Iterator<Item = Self>) -> Self;
}
impl CostTypeTrait for PreCost {
    /// Returns the per-token maximum over `values`.
    ///
    /// The accumulator starts empty and a token missing from it counts as `0`. As a result,
    /// every token in the result is at least `0`, and an empty iterator yields
    /// `PreCost::default()`.
    fn max(values: impl Iterator<Item = Self>) -> Self {
        values.fold(Self::default(), |mut acc, cost| {
            for (token_type, amount) in cost.0 {
                let best_so_far = *acc.0.get(&token_type).unwrap_or(&0);
                acc.0.insert(token_type, std::cmp::max(best_so_far, amount));
            }
            acc
        })
    }
}
pub fn compute_costs<
CostType: CostTypeTrait,
SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
program: &Program,
get_cost_fn: &dyn Fn(&ConcreteLibfuncId) -> Vec<BranchCost>,
specific_cost_context: &SpecificCostContext,
) -> GasInfo {
let mut context = CostContext { program, costs: UnorderedHashMap::default(), get_cost_fn };
for i in 0..program.statements.len() {
context.prepare_wallet_at(&StatementIdx(i), specific_cost_context);
}
let mut variable_values = VariableValues::default();
for i in 0..program.statements.len() {
analyze_gas_statements(
&context,
specific_cost_context,
&StatementIdx(i),
&mut variable_values,
);
}
let function_costs = program
.funcs
.iter()
.map(|func| {
let res = SpecificCostContext::to_cost_map(
context.wallet_at(&func.entry_point).get_pure_value(),
);
(func.id.clone(), res)
})
.collect();
GasInfo { variable_values, function_costs }
}
/// Returns the statements whose wallets must be computed before the wallet at `idx`.
///
/// A function-call branch depends on the callee's entry point. Every branch also depends on its
/// target statement, except a successful `withdraw_gas` branch, which depends on nothing.
///
/// # Panics
/// If `invocation.branches` and `libfunc_cost` have different lengths.
fn get_branch_requirements_dependencies(
    idx: &StatementIdx,
    invocation: &Invocation,
    libfunc_cost: &[BranchCost],
) -> OrderedHashSet<StatementIdx> {
    let mut dependencies: OrderedHashSet<StatementIdx> = Default::default();
    for (branch_info, branch_cost) in zip_eq(&invocation.branches, libfunc_cost) {
        // A successful withdraw_gas does not take the wallet after it into account.
        if matches!(branch_cost, BranchCost::WithdrawGas { success: true, .. }) {
            continue;
        }
        if let BranchCost::FunctionCall { function, .. } = branch_cost {
            dependencies.insert(function.entry_point);
        }
        dependencies.insert(idx.next(&branch_info.target));
    }
    dependencies
}
/// Returns the wallet requirement of each branch of the invocation at `idx`, in branch order.
///
/// Each requirement is delegated to
/// [`SpecificCostContextTrait::get_branch_requirement`].
///
/// # Panics
/// If `invocation.branches` and `libfunc_cost` have different lengths.
fn get_branch_requirements<
    CostType: CostTypeTrait,
    SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
    specific_context: &SpecificCostContext,
    wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<CostType>,
    idx: &StatementIdx,
    invocation: &Invocation,
    libfunc_cost: &[BranchCost],
) -> Vec<WalletInfo<CostType>> {
    let mut requirements = Vec::with_capacity(invocation.branches.len());
    for (branch_info, branch_cost) in zip_eq(&invocation.branches, libfunc_cost) {
        let requirement =
            specific_context.get_branch_requirement(wallet_at_fn, idx, branch_info, branch_cost);
        requirements.push(requirement);
    }
    requirements
}
/// Records the gas variable values for the statement at `idx` into `variable_values`.
///
/// Only invocations with more than one branch record anything; returns and single-branch
/// invocations are skipped. The recording depends on the branch:
/// - A successful `withdraw_gas` branch gets the per-token amounts from
///   [`SpecificCostContextTrait::get_withdraw_gas_values`]. The positive part of each amount
///   is recorded under `idx`, and the positive part of its negation under the branch target.
/// - Every other branch gets the amounts from
///   [`SpecificCostContextTrait::get_branch_align_values`], recorded under the branch target.
///
/// # Panics
/// - If a value is recorded twice for the same `(statement, token)` key.
/// - If a required wallet was not computed yet.
fn analyze_gas_statements<
    CostType: CostTypeTrait,
    SpecificCostContext: SpecificCostContextTrait<CostType>,
>(
    context: &CostContext<'_, CostType>,
    specific_context: &SpecificCostContext,
    idx: &StatementIdx,
    variable_values: &mut VariableValues,
) {
    let Statement::Invocation(invocation) = &context.program.get_statement(idx).unwrap() else {
        return;
    };
    // Single-branch invocations never record values. Bail out before computing branch
    // requirements that would only be thrown away.
    if invocation.branches.len() <= 1 {
        return;
    }
    let libfunc_cost: Vec<BranchCost> = context.get_cost(&invocation.libfunc_id);
    let branch_requirements: Vec<WalletInfo<CostType>> = get_branch_requirements(
        specific_context,
        &|statement_idx| context.wallet_at(statement_idx),
        idx,
        invocation,
        &libfunc_cost,
    );
    // `wallet_at` already returns an owned copy, so move the value out instead of cloning again.
    let wallet_value = context.wallet_at(idx).value;
    for (branch_info, branch_cost, branch_requirement) in
        zip_eq3(&invocation.branches, &libfunc_cost, &branch_requirements)
    {
        let target_idx = idx.next(&branch_info.target);
        if let BranchCost::WithdrawGas { success: true, .. } = branch_cost {
            let future_wallet_value = context.wallet_at(&target_idx).value;
            for (token_type, amount) in specific_context.get_withdraw_gas_values(
                branch_cost,
                &wallet_value,
                future_wallet_value,
            ) {
                assert_eq!(
                    variable_values.insert((*idx, token_type), std::cmp::max(amount, 0)),
                    None
                );
                assert_eq!(
                    variable_values.insert((target_idx, token_type), std::cmp::max(-amount, 0)),
                    None
                );
            }
        } else {
            for (token_type, amount) in
                specific_context.get_branch_align_values(&wallet_value, &branch_requirement.value)
            {
                assert_eq!(variable_values.insert((target_idx, token_type), amount), None);
            }
        }
    }
}
/// Cost-type-specific rules plugged into the generic cost computation of this module.
pub trait SpecificCostContextTrait<CostType>
where
    CostType: CostTypeTrait,
{
    /// Converts a cost value into a per-token map. `compute_costs` uses this to report the cost
    /// of each function's entry point.
    fn to_cost_map(cost: CostType) -> OrderedHashMap<CostTokenType, i64>;

    /// Returns the wallet required for taking the branch `branch_info`, whose cost is
    /// `branch_cost`, of the statement at `idx`.
    ///
    /// `wallet_at_fn` looks up the wallets of other statements. The generic code backs it with
    /// the cache, which panics on statements that were not computed yet.
    fn get_branch_requirement(
        &self,
        wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<CostType>,
        idx: &StatementIdx,
        branch_info: &BranchInfo,
        branch_cost: &BranchCost,
    ) -> WalletInfo<CostType>;

    /// Returns the per-token amounts for a successful `withdraw_gas` branch.
    ///
    /// `wallet_value` is the wallet at the statement, and `future_wallet_value` is the wallet at
    /// the branch target.
    fn get_withdraw_gas_values(
        &self,
        branch_cost: &BranchCost,
        wallet_value: &CostType,
        future_wallet_value: CostType,
    ) -> OrderedHashMap<CostTokenType, i64>;

    /// Returns the per-token amounts recorded at the target of a branch of a multi-branch
    /// statement.
    ///
    /// `wallet_value` is the wallet at the statement, and `branch_requirement` is that branch's
    /// requirement.
    fn get_branch_align_values(
        &self,
        wallet_value: &CostType,
        branch_requirement: &CostType,
    ) -> OrderedHashMap<CostTokenType, i64>;
}
/// The wallet computed for a statement, wrapping a single cost value.
///
/// For a return statement the wallet is the default value. For an invocation it is the
/// [`WalletInfo::merge`] of the requirements of its branches.
#[derive(Clone, Debug, Default)]
pub struct WalletInfo<CostType: CostTypeTrait> {
    value: CostType,
}
impl<CostType: CostTypeTrait> WalletInfo<CostType> {
    /// Joins several branch requirements into one by applying `CostType::max` to their values.
    ///
    /// Consumes `branches` so that each value is moved rather than cloned.
    fn merge(branches: Vec<Self>) -> Self {
        let max_value = CostType::max(branches.into_iter().map(|wallet_info| wallet_info.value));
        WalletInfo { value: max_value }
    }
    /// Returns a copy of the wrapped value.
    fn get_value(&self) -> CostType {
        self.value.clone()
    }
    /// Consumes the wallet and returns the wrapped value without cloning it.
    fn get_pure_value(self) -> CostType {
        self.value
    }
}
impl<CostType> From<CostType> for WalletInfo<CostType>
where
    CostType: CostTypeTrait,
{
    /// Wraps a raw cost value as a wallet.
    fn from(cost: CostType) -> Self {
        Self { value: cost }
    }
}
impl<CostType> std::ops::Add for WalletInfo<CostType>
where
    CostType: CostTypeTrait,
{
    type Output = Self;

    /// Adds the wrapped cost values of two wallets.
    fn add(self, rhs: Self) -> Self {
        let summed = self.value + rhs.value;
        Self { value: summed }
    }
}
/// Progress of the wallet computation for a single statement, as tracked by
/// `CostContext::prepare_wallet_at`.
enum CostComputationStatus<CostType>
where
    CostType: CostTypeTrait,
{
    /// The statement was visited and its missing dependencies were scheduled, but its wallet
    /// is not known yet.
    InProgress,
    /// The wallet of the statement is known.
    Done(WalletInfo<CostType>),
}
/// State of a cost computation: the program, the libfunc cost oracle, and the per-statement
/// wallet cache.
struct CostContext<'a, CostType>
where
    CostType: CostTypeTrait,
{
    /// Per-statement computation status; holds the wallet once it is computed.
    costs: UnorderedHashMap<StatementIdx, CostComputationStatus<CostType>>,
    /// The program whose statements are being costed.
    program: &'a Program,
    /// Returns the cost of each branch of a libfunc.
    get_cost_fn: &'a dyn Fn(&ConcreteLibfuncId) -> Vec<BranchCost>,
}
impl<'a, CostType: CostTypeTrait> CostContext<'a, CostType> {
    /// Returns the per-branch cost of `libfunc_id` as reported by `get_cost_fn`.
    fn get_cost(&self, libfunc_id: &ConcreteLibfuncId) -> Vec<BranchCost> {
        (self.get_cost_fn)(libfunc_id)
    }

    /// Returns a copy of the cached wallet at `idx`.
    ///
    /// # Panics
    /// If the wallet at `idx` is not `Done` yet. Call `prepare_wallet_at` first.
    fn wallet_at(&self, idx: &StatementIdx) -> WalletInfo<CostType> {
        match self.costs.get(idx) {
            Some(CostComputationStatus::Done(res)) => res.clone(),
            _ => {
                panic!("Wallet value for statement {idx} was not yet computed.")
            }
        }
    }

    /// Computes and caches the wallet at `idx` and at every statement it transitively depends
    /// on. It uses an explicit stack instead of recursion.
    ///
    /// A statement is visited twice:
    /// - On the first visit it is marked `InProgress`, and its not-yet-visited dependencies are
    ///   pushed above it.
    /// - On the second visit, every dependency has been popped as `Done`, so its wallet is
    ///   computed and cached.
    ///
    /// # Panics
    /// If a dependency is found that is still `InProgress`, which indicates a cycle.
    fn prepare_wallet_at<SpecificCostContext: SpecificCostContextTrait<CostType>>(
        &mut self,
        idx: &StatementIdx,
        specific_cost_context: &SpecificCostContext,
    ) {
        let mut statements_to_visit = vec![*idx];
        while let Some(&current_idx) = statements_to_visit.last() {
            match self.costs.get(&current_idx) {
                Some(CostComputationStatus::InProgress) => {
                    // Second visit: all dependencies are `Done` by now.
                    let res = self.no_cache_compute_wallet_at(&current_idx, specific_cost_context);
                    self.costs.insert(current_idx, CostComputationStatus::Done(res));
                    statements_to_visit.pop();
                    continue;
                }
                Some(CostComputationStatus::Done(_)) => {
                    // Already computed (e.g. pushed more than once).
                    statements_to_visit.pop();
                    continue;
                }
                None => (),
            }
            // First visit: schedule the dependencies that were not visited yet.
            self.costs.insert(current_idx, CostComputationStatus::InProgress);
            match &self.program.get_statement(&current_idx).unwrap() {
                Statement::Return(_) => {}
                Statement::Invocation(invocation) => {
                    let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
                    let missing_dependencies = get_branch_requirements_dependencies(
                        &current_idx,
                        invocation,
                        &libfunc_cost,
                    )
                    .into_iter()
                    .filter(|dep| match self.costs.get(dep) {
                        None => true,
                        Some(CostComputationStatus::Done(_)) => false,
                        Some(CostComputationStatus::InProgress) => {
                            panic!("Found an unexpected cycle during cost computation.");
                        }
                    });
                    statements_to_visit.extend(missing_dependencies);
                }
            };
        }
    }

    /// Computes the wallet at `idx` from the cached wallets of its dependencies.
    ///
    /// It neither reads nor writes the cache entry of `idx` itself. A return statement yields
    /// the default wallet. An invocation yields the merge of its branch requirements.
    ///
    /// Only called from `prepare_wallet_at`, after all dependencies of `idx` are `Done`. If
    /// one is missing, `wallet_at` panics.
    fn no_cache_compute_wallet_at<SpecificCostContext: SpecificCostContextTrait<CostType>>(
        &self,
        idx: &StatementIdx,
        specific_cost_context: &SpecificCostContext,
    ) -> WalletInfo<CostType> {
        match &self.program.get_statement(idx).unwrap() {
            Statement::Return(_) => Default::default(),
            Statement::Invocation(invocation) => {
                let libfunc_cost: Vec<BranchCost> = self.get_cost(&invocation.libfunc_id);
                let branch_requirements: Vec<WalletInfo<CostType>> = get_branch_requirements(
                    specific_cost_context,
                    &|statement_idx| self.wallet_at(statement_idx),
                    idx,
                    invocation,
                    &libfunc_cost,
                );
                WalletInfo::merge(branch_requirements)
            }
        }
    }
}
/// Stateless context that implements `SpecificCostContextTrait<PreCost>`, i.e. the rules for
/// computing `PreCost` values.
pub struct PreCostContext {}
impl SpecificCostContextTrait<PreCost> for PreCostContext {
fn to_cost_map(cost: PreCost) -> OrderedHashMap<CostTokenType, i64> {
let res = cost.0;
res.into_iter().map(|(token_type, val)| (token_type, val as i64)).collect()
}
fn get_withdraw_gas_values(
&self,
_branch_cost: &BranchCost,
wallet_value: &PreCost,
future_wallet_value: PreCost,
) -> OrderedHashMap<CostTokenType, i64> {
let res = (future_wallet_value - wallet_value.clone()).0;
CostTokenType::iter_precost()
.map(|token_type| (*token_type, *res.get(token_type).unwrap_or(&0) as i64))
.collect()
}
fn get_branch_align_values(
&self,
wallet_value: &PreCost,
branch_requirement: &PreCost,
) -> OrderedHashMap<CostTokenType, i64> {
let res = (wallet_value.clone() - branch_requirement.clone()).0;
CostTokenType::iter_precost()
.map(|token_type| (*token_type, *res.get(token_type).unwrap_or(&0) as i64))
.collect()
}
fn get_branch_requirement(
&self,
wallet_at_fn: &dyn Fn(&StatementIdx) -> WalletInfo<PreCost>,
idx: &StatementIdx,
branch_info: &BranchInfo,
branch_cost: &BranchCost,
) -> WalletInfo<PreCost> {
let branch_cost = match branch_cost {
BranchCost::Regular { const_cost: _, pre_cost } => pre_cost.clone(),
BranchCost::BranchAlign => Default::default(),
BranchCost::FunctionCall { const_cost: _, function } => {
wallet_at_fn(&function.entry_point).get_pure_value()
}
BranchCost::WithdrawGas { const_cost: _, success, with_builtin_costs: _ } => {
if *success {
return Default::default();
} else {
Default::default()
}
}
BranchCost::RedepositGas => {
Default::default()
}
};
let future_wallet_value = wallet_at_fn(&idx.next(&branch_info.target));
WalletInfo::from(branch_cost) + future_wallet_value
}
}