// cairo_lang_sierra_ap_change/compute.rs
use cairo_lang_sierra::algorithm::topological_order::get_topological_ordering;
use cairo_lang_sierra::extensions::core::{CoreLibfunc, CoreType};
use cairo_lang_sierra::extensions::gas::CostTokenType;
use cairo_lang_sierra::ids::{ConcreteTypeId, FunctionId};
use cairo_lang_sierra::program::{Program, Statement, StatementIdx};
use cairo_lang_sierra::program_registry::ProgramRegistry;
use cairo_lang_sierra_type_size::{get_type_size_map, TypeSizeMap};
use cairo_lang_utils::casts::IntoOrPanic;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::unordered_hash_map::{Entry, UnorderedHashMap};
use crate::ap_change_info::ApChangeInfo;
use crate::core_libfunc_ap_change::{self, InvocationApChangeInfoProvider};
use crate::{ApChange, ApChangeError};
/// Adapter that exposes a type-size map and a per-token usage callback through
/// [`InvocationApChangeInfoProvider`].
struct InvocationApChangeInfoProviderForEqGen<'a, TokenUsages>
where
    TokenUsages: Fn(CostTokenType) -> usize,
{
    /// Type sizes, looked up by concrete type id.
    type_sizes: &'a TypeSizeMap,
    /// Callback answering `token_usages` queries for a given cost token type.
    token_usages: TokenUsages,
}
impl<TokenUsages> InvocationApChangeInfoProvider
    for InvocationApChangeInfoProviderForEqGen<'_, TokenUsages>
where
    TokenUsages: Fn(CostTokenType) -> usize,
{
    /// Returns the size recorded for `ty` in the type-size map, converted to `usize`.
    ///
    /// The conversion goes through `into_or_panic`; indexing presumably panics as well when
    /// `ty` is absent from the map (TODO confirm against `TypeSizeMap`'s `Index` impl).
    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
        let size = self.type_sizes[ty];
        size.into_or_panic()
    }

    /// Forwards the query to the wrapped usage callback.
    fn token_usages(&self, token_type: CostTokenType) -> usize {
        let usages = &self.token_usages;
        usages(token_type)
    }
}
/// The point relative to which a statement's AP change is tracked.
#[derive(Clone, Debug)]
enum ApTrackingBase {
/// Tracking is relative to the entry point of the given function.
FunctionStart(FunctionId),
/// Tracking was started by an `EnableApTracking` branch of the statement at the given index.
// The payload is stored but never read in the visible code, hence the lint allowance.
#[allow(dead_code)]
EnableStatement(StatementIdx),
}
/// Tracking state attached to a statement: which base it is tracked from and the AP change
/// accumulated since that base.
#[derive(Clone, Debug)]
struct ApTrackingInfo {
/// The base the AP change is measured from.
base: ApTrackingBase,
/// The AP change accumulated from `base` up to the statement; when several branches reach the
/// same statement the maximum is kept.
ap_change: usize,
}
/// Working state for computing the AP-change information of a program.
struct ApChangeCalcHelper<'a, TokenUsages>
where
    TokenUsages: Fn(StatementIdx, CostTokenType) -> usize,
{
    /// The program being analysed.
    program: &'a Program,
    /// Registry used to resolve the program's libfuncs and functions.
    registry: ProgramRegistry<CoreType, CoreLibfunc>,
    /// Type sizes computed for the program.
    type_sizes: TypeSizeMap,
    /// Callback queried with a statement index and a cost token type.
    token_usages: TokenUsages,
    /// Locals size recorded per statement by `calc_locals_for_statement`.
    locals_size: UnorderedHashMap<StatementIdx, usize>,
    /// For statements where it is known: the maximum, over the statement's branches, of the
    /// branch AP change plus the target's own value (0 for a statement without branches).
    known_ap_change_to_return: UnorderedHashMap<StatementIdx, usize>,
    /// Known AP change per function, taken from `known_ap_change_to_return` at its entry point.
    function_ap_change: OrderedHashMap<FunctionId, usize>,
    /// Tracking base and accumulated AP change per statement.
    tracking_info: UnorderedHashMap<StatementIdx, ApTrackingInfo>,
    /// Per-statement AP change as fixed by
    /// `calc_effective_ap_change_and_variables_per_statement`.
    effective_ap_change_from_base: UnorderedHashMap<StatementIdx, usize>,
    /// Per-target difference between a branch's AP change and the one chosen for its source
    /// statement; this map is handed back to the caller in `ApChangeInfo`.
    variable_values: OrderedHashMap<StatementIdx, usize>,
}
impl<'a, TokenUsages: Fn(StatementIdx, CostTokenType) -> usize>
ApChangeCalcHelper<'a, TokenUsages>
{
/// Creates a helper for `program`: builds the program registry and the type-size map, and
/// starts every per-statement and per-function table empty.
///
/// # Errors
/// Propagates a failure of `ProgramRegistry::new`, converted into [`ApChangeError`].
///
/// # Panics
/// Panics if the type-size map cannot be computed for `program`.
fn new(program: &'a Program, token_usages: TokenUsages) -> Result<Self, ApChangeError> {
    let registry = ProgramRegistry::<CoreType, CoreLibfunc>::new(program)?;
    // The registry is only borrowed here; it is moved into the helper below.
    let type_sizes = get_type_size_map(program, &registry)
        .expect("failed to compute the type sizes of the program");
    Ok(Self {
        program,
        registry,
        type_sizes,
        token_usages,
        locals_size: Default::default(),
        known_ap_change_to_return: Default::default(),
        function_ap_change: Default::default(),
        tracking_info: Default::default(),
        effective_ap_change_from_base: Default::default(),
        variable_values: Default::default(),
    })
}
fn calc_locals_and_function_ap_changes(&mut self) -> Result<(), ApChangeError> {
let ordering = self.known_ap_change_topological_order()?;
for idx in ordering.iter().rev() {
self.calc_locals_for_statement(*idx)?;
}
for idx in ordering {
self.calc_known_ap_change_for_statement(idx)?;
}
self.function_ap_change = self
.program
.funcs
.iter()
.filter_map(|f| {
self.known_ap_change_to_return
.get(&f.entry_point)
.cloned()
.map(|ap_change| (f.id.clone(), ap_change))
})
.collect();
Ok(())
}
/// Propagates the locals size of statement `idx` to the targets of its branches.
///
/// * `AtLocalsFinalization(x)` branches give the target the statement's locals size plus `x`.
/// * `Unknown` and `FinalizeLocals` branches record nothing for the target.
/// * Every other branch copies the statement's recorded locals size, if there is one.
fn calc_locals_for_statement(&mut self, idx: StatementIdx) -> Result<(), ApChangeError> {
    for (ap_change, target) in self.get_branches(idx)? {
        let target_locals = match ap_change {
            ApChange::AtLocalsFinalization(added) => Some(self.get_statement_locals(idx) + added),
            ApChange::Unknown | ApChange::FinalizeLocals => None,
            ApChange::FromMetadata
            | ApChange::FunctionCall(_)
            | ApChange::EnableApTracking
            | ApChange::Known(_)
            | ApChange::DisableApTracking => self.locals_size.get(&idx).copied(),
        };
        if let Some(size) = target_locals {
            self.locals_size.insert(target, size);
        }
    }
    Ok(())
}
/// Records in `known_ap_change_to_return` the largest value, over the branches of `idx`, of
/// the branch's AP change plus the value already known for the branch target.
///
/// Nothing is recorded when any target has no known value or any branch has no known AP
/// change. A statement without branches is recorded with 0.
fn calc_known_ap_change_for_statement(
    &mut self,
    idx: StatementIdx,
) -> Result<(), ApChangeError> {
    let mut max_change: usize = 0;
    for (ap_change, target) in self.get_branches(idx)? {
        let Some(target_change) = self.known_ap_change_to_return.get(&target).copied() else {
            return Ok(());
        };
        // A called function contributes the value known at its entry point, if any.
        let branch_change = self.branch_ap_change(idx, &ap_change, |id| {
            let entry_point = self.func_entry_point(id).ok()?;
            self.known_ap_change_to_return.get(&entry_point).copied()
        });
        let Some(branch_change) = branch_change else {
            return Ok(());
        };
        max_change = max_change.max(target_change + branch_change);
    }
    self.known_ap_change_to_return.insert(idx, max_change);
    Ok(())
}
/// Orders all statements topologically, where the children of a statement are its branch
/// targets plus, for function-call branches, the entry point of the called function.
///
/// # Errors
/// Returns an error if the branches of a statement or a callee's entry point cannot be
/// resolved; `ApChangeError::StatementOutOfBounds` is the out-of-bounds error handed to the
/// ordering algorithm.
fn known_ap_change_topological_order(&self) -> Result<Vec<StatementIdx>, ApChangeError> {
    let statement_count = self.program.statements.len();
    let children = |idx: StatementIdx| -> Result<Vec<StatementIdx>, ApChangeError> {
        let mut deps = Vec::new();
        for (ap_change, target) in self.get_branches(idx)? {
            deps.push(target);
            if let ApChange::FunctionCall(id) = ap_change {
                deps.push(self.func_entry_point(&id)?);
            }
        }
        Ok(deps)
    };
    get_topological_ordering(
        // Presumably the cycle-detection switch, given the `unreachable!` cycle handler below.
        false,
        (0..statement_count).map(StatementIdx),
        statement_count,
        children,
        ApChangeError::StatementOutOfBounds,
        |_| unreachable!("Cycle isn't an error."),
    )
}
/// Orders all statements topologically, following only branches that keep an AP change
/// trackable: `Unknown` branches are dropped, and function-call branches are kept only when
/// the called function has an entry in `function_ap_change`.
///
/// # Errors
/// Returns an error if the branches of a statement cannot be resolved;
/// `ApChangeError::StatementOutOfBounds` is the out-of-bounds error handed to the ordering
/// algorithm.
fn tracked_ap_change_topological_order(&self) -> Result<Vec<StatementIdx>, ApChangeError> {
    let statement_count = self.program.statements.len();
    let children = |idx: StatementIdx| -> Result<Vec<StatementIdx>, ApChangeError> {
        let tracked_targets: Vec<StatementIdx> = self
            .get_branches(idx)?
            .into_iter()
            .filter(|(ap_change, _)| match ap_change {
                ApChange::Unknown => false,
                ApChange::FunctionCall(id) => self.function_ap_change.contains_key(id),
                ApChange::Known(_)
                | ApChange::DisableApTracking
                | ApChange::FromMetadata
                | ApChange::AtLocalsFinalization(_)
                | ApChange::FinalizeLocals
                | ApChange::EnableApTracking => true,
            })
            .map(|(_, target)| target)
            .collect();
        Ok(tracked_targets)
    };
    get_topological_ordering(
        // Presumably the cycle-detection switch, given the `unreachable!` cycle handler below.
        false,
        (0..statement_count).map(StatementIdx),
        statement_count,
        children,
        ApChangeError::StatementOutOfBounds,
        |_| unreachable!("Cycle isn't an error."),
    )
}
/// Pushes the tracking info of statement `idx` forward to the targets of its branches.
///
/// * An `EnableApTracking` branch restarts tracking at its target, based on `idx`, with an AP
///   change of 0.
/// * Otherwise, if `idx` has tracking info and the branch has a known AP change, the target
///   receives the same base with the accumulated change; a target that already has tracking
///   info keeps its base and the larger of the two AP changes.
fn calc_tracking_info_for_statement(&mut self, idx: StatementIdx) -> Result<(), ApChangeError> {
    for (ap_change, target) in self.get_branches(idx)? {
        if let ApChange::EnableApTracking = ap_change {
            let restarted =
                ApTrackingInfo { base: ApTrackingBase::EnableStatement(idx), ap_change: 0 };
            self.tracking_info.insert(target, restarted);
            continue;
        }
        let Some(source_info) = self.tracking_info.get(&idx) else {
            continue;
        };
        let branch_change = self
            .branch_ap_change(idx, &ap_change, |id| self.function_ap_change.get(id).copied());
        let Some(branch_change) = branch_change else {
            continue;
        };
        let candidate = ApTrackingInfo {
            base: source_info.base.clone(),
            ap_change: source_info.ap_change + branch_change,
        };
        match self.tracking_info.entry(target) {
            Entry::Occupied(occupied) => {
                let existing = occupied.into_mut();
                existing.ap_change = existing.ap_change.max(candidate.ap_change);
            }
            Entry::Vacant(vacant) => {
                vacant.insert(candidate);
            }
        }
    }
    Ok(())
}
/// Fixes the effective AP change of statement `idx` and records, for each branch target whose
/// path disagrees with it, the difference in `variable_values`.
///
/// Statements without tracking info are skipped. A return statement tracked from a function
/// start takes that function's AP change, when known. For any other statement, each tracked
/// branch with a known change and a target with an effective AP change yields a candidate
/// (`target value - branch change`); the statement gets the smallest candidate, except that a
/// branch with no known change resets the running value to the statement's tracked AP change.
fn calc_effective_ap_change_and_variables_per_statement(
    &mut self,
    idx: StatementIdx,
) -> Result<(), ApChangeError> {
    let Some(base_info) = self.tracking_info.get(&idx).cloned() else {
        return Ok(());
    };
    if let Some(Statement::Return(_)) = self.program.get_statement(&idx) {
        if let ApTrackingBase::FunctionStart(id) = &base_info.base {
            if let Some(func_change) = self.function_ap_change.get(id).copied() {
                self.effective_ap_change_from_base.insert(idx, func_change);
            }
        }
        return Ok(());
    }
    let mut source_ap_change: Option<usize> = None;
    let mut paths_ap_change = Vec::new();
    for (ap_change, target) in self.get_branches(idx)? {
        if matches!(ap_change, ApChange::EnableApTracking) {
            continue;
        }
        let branch_change = self
            .branch_ap_change(idx, &ap_change, |id| self.function_ap_change.get(id).copied());
        let Some(branch_change) = branch_change else {
            // No known change on this branch: fall back to the tracked AP change of `idx`.
            source_ap_change = Some(base_info.ap_change);
            continue;
        };
        let Some(target_change) = self.effective_ap_change_from_base.get(&target).copied()
        else {
            continue;
        };
        let path_change = target_change - branch_change;
        paths_ap_change.push((target, path_change));
        source_ap_change =
            Some(source_ap_change.map_or(path_change, |current| current.min(path_change)));
    }
    let Some(source_ap_change) = source_ap_change else {
        return Ok(());
    };
    self.effective_ap_change_from_base.insert(idx, source_ap_change);
    for (target, path_change) in paths_ap_change {
        if path_change != source_ap_change {
            self.variable_values.insert(target, path_change - source_ap_change);
        }
    }
    Ok(())
}
/// Returns the concrete AP change of a branch of statement `idx`, or `None` when it is not
/// known.
///
/// * `Unknown` and `DisableApTracking` have no known change.
/// * `Known(x)` is `x`; `FromMetadata`, `AtLocalsFinalization` and `EnableApTracking` are 0.
/// * `FinalizeLocals` is the locals size recorded for `idx`.
/// * `FunctionCall(id)` is 2 plus whatever `func_ap_change` reports for `id` (the constant
///   presumably covers the call itself — TODO confirm), or `None` if it reports nothing.
fn branch_ap_change(
    &self,
    idx: StatementIdx,
    ap_change: &ApChange,
    func_ap_change: impl Fn(&FunctionId) -> Option<usize>,
) -> Option<usize> {
    let change = match ap_change {
        ApChange::Unknown | ApChange::DisableApTracking => return None,
        ApChange::FunctionCall(id) => 2 + func_ap_change(id)?,
        ApChange::FinalizeLocals => self.get_statement_locals(idx),
        ApChange::Known(known) => *known,
        ApChange::FromMetadata
        | ApChange::AtLocalsFinalization(_)
        | ApChange::EnableApTracking => 0,
    };
    Some(change)
}
/// Returns the locals size recorded for statement `idx`, or 0 when none is recorded.
fn get_statement_locals(&self, idx: StatementIdx) -> usize {
    match self.locals_size.get(&idx) {
        Some(size) => *size,
        None => 0,
    }
}
fn get_branches(
&self,
idx: StatementIdx,
) -> Result<Vec<(ApChange, StatementIdx)>, ApChangeError> {
Ok(match self.program.get_statement(&idx).unwrap() {
Statement::Invocation(invocation) => {
let libfunc = self.registry.get_libfunc(&invocation.libfunc_id)?;
core_libfunc_ap_change::core_libfunc_ap_change(
libfunc,
&InvocationApChangeInfoProviderForEqGen {
type_sizes: &self.type_sizes,
token_usages: |token_type| (self.token_usages)(idx, token_type),
},
)
.into_iter()
.zip(&invocation.branches)
.map(|(ap_change, branch_info)| (ap_change, idx.next(&branch_info.target)))
.collect()
}
Statement::Return(_) => vec![],
})
}
/// Looks up function `id` in the registry and returns the index of its entry-point statement.
///
/// # Errors
/// Propagates the registry's lookup failure, converted into [`ApChangeError`].
fn func_entry_point(&self, id: &FunctionId) -> Result<StatementIdx, ApChangeError> {
    let function = self.registry.get_function(id)?;
    Ok(function.entry_point)
}
}
pub fn calc_ap_changes<TokenUsages: Fn(StatementIdx, CostTokenType) -> usize>(
program: &Program,
token_usages: TokenUsages,
) -> Result<ApChangeInfo, ApChangeError> {
let mut helper = ApChangeCalcHelper::new(program, token_usages)?;
helper.calc_locals_and_function_ap_changes()?;
let ap_tracked_topological_ordering = helper.tracked_ap_change_topological_order()?;
for f in &program.funcs {
helper.tracking_info.insert(
f.entry_point,
ApTrackingInfo { base: ApTrackingBase::FunctionStart(f.id.clone()), ap_change: 0 },
);
}
for idx in ap_tracked_topological_ordering.iter().rev() {
helper.calc_tracking_info_for_statement(*idx)?;
}
for idx in ap_tracked_topological_ordering {
helper.calc_effective_ap_change_and_variables_per_statement(idx)?;
}
Ok(ApChangeInfo {
variable_values: helper.variable_values,
function_ap_change: helper.function_ap_change,
})
}