use cairo_lang_eq_solver::Expr;
use cairo_lang_sierra::extensions::circuit::{CircuitInfo, CircuitTypeConcrete, ConcreteCircuit};
use cairo_lang_sierra::extensions::core::{
CoreConcreteLibfunc, CoreLibfunc, CoreType, CoreTypeConcrete,
};
use cairo_lang_sierra::extensions::coupon::CouponConcreteLibfunc;
use cairo_lang_sierra::extensions::gas::{CostTokenType, GasConcreteLibfunc};
use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId};
use cairo_lang_sierra::program::{Program, Statement, StatementIdx};
use cairo_lang_sierra::program_registry::{ProgramRegistry, ProgramRegistryError};
use cairo_lang_sierra_type_size::{TypeSizeMap, get_type_size_map};
use cairo_lang_utils::casts::IntoOrPanic;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use compute_costs::PostCostTypeEx;
use core_libfunc_cost_base::InvocationCostInfoProvider;
use core_libfunc_cost_expr::CostExprMap;
use cost_expr::Var;
use gas_info::GasInfo;
use generate_equations::StatementFutureCost;
use itertools::Itertools;
use objects::CostInfoProvider;
use thiserror::Error;
pub mod compute_costs;
pub mod core_libfunc_cost;
mod core_libfunc_cost_base;
mod core_libfunc_cost_expr;
mod cost_expr;
pub mod gas_info;
mod generate_equations;
pub mod objects;
mod starknet_libfunc_cost_base;
#[cfg(test)]
mod test;
/// Errors that can occur while computing the gas costs of a Sierra program.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum CostError {
    /// Building or querying the program registry failed; wraps the boxed registry error
    /// (convertible with `?` thanks to `#[from]`).
    #[error("error from the program registry")]
    ProgramRegistryError(#[from] Box<ProgramRegistryError>),
    /// A statement index encountered during the cost calculation is not part of the program;
    /// carries the offending index.
    #[error("found an illegal statement index during cost calculations")]
    StatementOutOfBounds(StatementIdx),
    /// The gas equations of some token type have no solution.
    #[error("failed solving the symbol tables")]
    SolvingGasEquationFailed,
    /// A cycle was found where the cost computation did not expect one.
    #[error("found an unexpected cycle during cost computation")]
    UnexpectedCycle,
    /// An enforced cost could not be satisfied; carries the statement index involved
    /// (presumably the entry point whose wallet value was enforced — TODO confirm in
    /// `compute_costs`).
    #[error("failed to enforce function cost")]
    EnforceWalletValueFailed(StatementIdx),
}
/// Adapter answering the cost-expression generator's queries from a borrowed type-size map and
/// two caller-supplied closures.
///
/// The bounds are kept on the struct itself so that closures passed at construction time get
/// their signatures inferred from them.
struct InvocationCostInfoProviderForEqGen<'a, TokenUsages, ApChangeVarValue>
where
    TokenUsages: Fn(CostTokenType) -> usize,
    ApChangeVarValue: Fn() -> usize,
{
    /// Sizes of the program's concrete types.
    type_sizes: &'a TypeSizeMap,
    /// Returns how many times a given cost token is used.
    token_usages: TokenUsages,
    /// Returns the value of the ap-change variable.
    ap_change_var_value: ApChangeVarValue,
}
impl<TokenUsages, ApChangeVarValue> InvocationCostInfoProvider
    for InvocationCostInfoProviderForEqGen<'_, TokenUsages, ApChangeVarValue>
where
    TokenUsages: Fn(CostTokenType) -> usize,
    ApChangeVarValue: Fn() -> usize,
{
    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
        let size = self.type_sizes[ty];
        size.into_or_panic()
    }
    fn token_usages(&self, token_type: CostTokenType) -> usize {
        let usages_of = &self.token_usages;
        usages_of(token_type)
    }
    fn ap_change_var_value(&self) -> usize {
        let value_of = &self.ap_change_var_value;
        value_of()
    }
    /// Not supported: the equation-based (old) gas solver does not handle circuits.
    fn circuit_info(&self, _ty: &ConcreteTypeId) -> &CircuitInfo {
        unimplemented!("circuits are not supported for old gas solver");
    }
}
/// Calculates gas pre-cost information for `program` by solving the gas equations built from the
/// pre-cost libfunc cost expressions.
///
/// `function_set_costs` pins the entry-point cost of the listed functions, per token type.
/// After solving, the pre-cost token variables of every `withdraw_gas` and `refund` invocation
/// are set to 0.
///
/// # Errors
/// Returns an error if the program registry cannot be built or the gas equations cannot be
/// solved.
///
/// # Panics
/// If the type sizes of the program cannot be computed, or if a `withdraw_gas`/`refund` token
/// variable was already assigned a value by the solver.
pub fn calc_gas_precost_info(
    program: &Program,
    function_set_costs: OrderedHashMap<FunctionId, OrderedHashMap<CostTokenType, i32>>,
) -> Result<GasInfo, CostError> {
    let cost_provider = ComputeCostInfoProvider::new(program)?;
    // The provider already built the registry; reuse it instead of constructing a second one.
    let registry = &cost_provider.registry;
    let mut info = calc_gas_info_inner(
        program,
        |statement_future_cost, idx, libfunc_id| -> Vec<OrderedHashMap<CostTokenType, Expr<Var>>> {
            let libfunc = registry
                .get_libfunc(libfunc_id)
                .expect("Program registry creation would have already failed.");
            core_libfunc_cost_expr::core_libfunc_precost_expr(
                statement_future_cost,
                idx,
                libfunc,
                &cost_provider,
            )
        },
        function_set_costs,
        registry,
    )?;
    for (i, statement) in program.statements.iter().enumerate() {
        let Statement::Invocation(invocation) = statement else {
            continue;
        };
        let Ok(libfunc) = registry.get_libfunc(&invocation.libfunc_id) else {
            continue;
        };
        let is_withdraw_gas_or_refund = matches!(
            libfunc,
            CoreConcreteLibfunc::Gas(GasConcreteLibfunc::WithdrawGas(_))
                | CoreConcreteLibfunc::Coupon(CouponConcreteLibfunc::Refund(_))
        );
        if is_withdraw_gas_or_refund {
            // Pin every pre-cost token variable of this statement to 0; the solver must not
            // have assigned these variables already.
            for token in CostTokenType::iter_precost() {
                assert_eq!(info.variable_values.insert((StatementIdx(i), *token), 0), None);
            }
        }
    }
    Ok(info)
}
/// Cost-info provider backed by the program's registry and its type-size map.
struct ComputeCostInfoProvider {
    /// Registry of the program's concrete types, libfuncs and functions.
    pub registry: ProgramRegistry<CoreType, CoreLibfunc>,
    /// Sizes of the program's concrete types, as computed by `get_type_size_map`.
    pub type_sizes: TypeSizeMap,
}
impl ComputeCostInfoProvider {
    /// Builds the program registry and the type-size map for `program`.
    ///
    /// # Errors
    /// Returns the registry error if the program registry cannot be built.
    ///
    /// # Panics
    /// If the type sizes of the program cannot be computed.
    fn new(program: &Program) -> Result<Self, Box<ProgramRegistryError>> {
        let registry = ProgramRegistry::<CoreType, CoreLibfunc>::new(program)?;
        let type_sizes = get_type_size_map(program, &registry)
            .expect("Failed to compute the type sizes of the program.");
        Ok(Self { registry, type_sizes })
    }
}
impl CostInfoProvider for ComputeCostInfoProvider {
    /// Looks up the size of `ty` and converts it to `usize`, panicking if either step fails.
    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
        let size = self.type_sizes[ty];
        size.into_or_panic()
    }
    /// Returns the circuit info attached to the concrete circuit type `ty`.
    ///
    /// # Panics
    /// If `ty` is not in the registry or is not a circuit type.
    fn circuit_info(&self, ty: &ConcreteTypeId) -> &CircuitInfo {
        match self.registry.get_type(ty).unwrap() {
            CoreTypeConcrete::Circuit(CircuitTypeConcrete::Circuit(ConcreteCircuit {
                circuit_info,
                ..
            })) => circuit_info,
            _ => panic!("Expected a circuit type, got {ty:?}."),
        }
    }
}
pub fn compute_precost_info(program: &Program) -> Result<GasInfo, CostError> {
let cost_provider = ComputeCostInfoProvider::new(program)?;
compute_costs::compute_costs(
program,
&(|libfunc_id| {
let core_libfunc = cost_provider
.registry
.get_libfunc(libfunc_id)
.expect("Program registry creation would have already failed.");
core_libfunc_cost_base::core_libfunc_cost(core_libfunc, &cost_provider)
}),
&compute_costs::PreCostContext {},
&Default::default(),
)
}
pub fn calc_gas_postcost_info<ApChangeVarValue: Fn(StatementIdx) -> usize>(
program: &Program,
function_set_costs: OrderedHashMap<FunctionId, OrderedHashMap<CostTokenType, i32>>,
precost_gas_info: &GasInfo,
ap_change_var_value: ApChangeVarValue,
) -> Result<GasInfo, CostError> {
let registry = ProgramRegistry::<CoreType, CoreLibfunc>::new(program)?;
let type_sizes = get_type_size_map(program, ®istry).unwrap();
let mut info = calc_gas_info_inner(
program,
|statement_future_cost, idx, libfunc_id| {
let libfunc = registry
.get_libfunc(libfunc_id)
.expect("Program registry creation would have already failed.");
core_libfunc_cost_expr::core_libfunc_postcost_expr(
statement_future_cost,
idx,
libfunc,
&InvocationCostInfoProviderForEqGen {
type_sizes: &type_sizes,
token_usages: |token_type| {
precost_gas_info.variable_values[&(*idx, token_type)].into_or_panic()
},
ap_change_var_value: || ap_change_var_value(*idx),
},
)
},
function_set_costs,
®istry,
)?;
for (i, statement) in program.statements.iter().enumerate() {
let Statement::Invocation(invocation) = statement else {
continue;
};
let Ok(libfunc) = registry.get_libfunc(&invocation.libfunc_id) else {
continue;
};
let is_refund =
matches!(libfunc, CoreConcreteLibfunc::Coupon(CouponConcreteLibfunc::Refund(_)));
if is_refund {
assert_eq!(
info.variable_values.insert((StatementIdx(i), CostTokenType::Const), 0),
None
);
}
}
Ok(info)
}
fn calc_gas_info_inner<
GetCost: Fn(&mut dyn StatementFutureCost, &StatementIdx, &ConcreteLibfuncId) -> Vec<CostExprMap>,
>(
program: &Program,
get_cost: GetCost,
function_set_costs: OrderedHashMap<FunctionId, OrderedHashMap<CostTokenType, i32>>,
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
) -> Result<GasInfo, CostError> {
let mut equations = generate_equations::generate_equations(program, get_cost)?;
let non_set_cost_func_entry_points: UnorderedHashSet<_> = program
.funcs
.iter()
.filter(|f| !function_set_costs.contains_key(&f.id))
.map(|f| f.entry_point)
.collect();
for (func_id, cost_terms) in function_set_costs {
for token_type in CostTokenType::iter_casm_tokens() {
equations[token_type].push(
Expr::from_var(Var::StatementFuture(
registry.get_function(&func_id)?.entry_point,
*token_type,
)) - Expr::from_const(cost_terms.get(token_type).copied().unwrap_or_default()),
);
}
}
let mut variable_values = OrderedHashMap::default();
let mut function_costs = OrderedHashMap::default();
for (token_type, token_equations) in equations {
let mut minimization_vars = vec![vec![], vec![], vec![]];
for v in token_equations.iter().flat_map(|eq| eq.var_to_coef.keys()).unique() {
minimization_vars[match v {
Var::LibfuncImplicitGasVariable(idx, _) => {
match program.get_statement(idx).unwrap() {
Statement::Invocation(invocation) => {
match registry.get_libfunc(&invocation.libfunc_id).unwrap() {
CoreConcreteLibfunc::BranchAlign(_) => 2,
CoreConcreteLibfunc::Gas(GasConcreteLibfunc::WithdrawGas(_)) => 1,
CoreConcreteLibfunc::Gas(
GasConcreteLibfunc::BuiltinWithdrawGas(_),
) => 0,
CoreConcreteLibfunc::Gas(GasConcreteLibfunc::RedepositGas(_)) => {
continue;
}
_ => unreachable!(
"Gas variables cannot originate from {}.",
invocation.libfunc_id
),
}
}
Statement::Return(_) => continue,
}
}
Var::StatementFuture(idx, _) if non_set_cost_func_entry_points.contains(idx) => 0,
Var::StatementFuture(_, _) => {
continue;
}
}]
.push(v.clone())
}
let solution =
cairo_lang_eq_solver::try_solve_equations(token_equations, minimization_vars)
.ok_or(CostError::SolvingGasEquationFailed)?;
for func in &program.funcs {
let id = &func.id;
if !function_costs.contains_key(id) {
function_costs.insert(id.clone(), OrderedHashMap::default());
}
if let Some(value) = solution.get(&Var::StatementFuture(func.entry_point, token_type)) {
if *value != 0 {
function_costs.get_mut(id).unwrap().insert(token_type, *value);
}
}
}
for (var, value) in solution {
if let Var::LibfuncImplicitGasVariable(idx, var_token_type) = var {
assert_eq!(
token_type, var_token_type,
"Unexpected variable of type {var_token_type:?} while handling {token_type:?}."
);
variable_values.insert((idx, var_token_type), value);
}
}
}
Ok(GasInfo { variable_values, function_costs })
}
pub fn compute_postcost_info<CostType: PostCostTypeEx>(
program: &Program,
get_ap_change_fn: &dyn Fn(&StatementIdx) -> usize,
precost_gas_info: &GasInfo,
enforced_function_costs: &OrderedHashMap<FunctionId, CostType>,
) -> Result<GasInfo, CostError> {
let cost_provider = ComputeCostInfoProvider::new(program)?;
let specific_cost_context =
compute_costs::PostcostContext { get_ap_change_fn, precost_gas_info };
compute_costs::compute_costs(
program,
&(|libfunc_id| {
let core_libfunc = cost_provider
.registry
.get_libfunc(libfunc_id)
.expect("Program registry creation would have already failed.");
core_libfunc_cost_base::core_libfunc_cost(core_libfunc, &cost_provider)
}),
&specific_cost_context,
&enforced_function_costs
.iter()
.map(|(func, val)| {
(
cost_provider
.registry
.get_function(func)
.expect("Program registry creation would have already failed.")
.entry_point,
*val,
)
})
.collect(),
)
}