cairo_lang_sierra_ap_change/
compute.rs

1use cairo_lang_sierra::algorithm::topological_order::reverse_topological_ordering;
2use cairo_lang_sierra::extensions::core::{CoreLibfunc, CoreType};
3use cairo_lang_sierra::extensions::gas::CostTokenType;
4use cairo_lang_sierra::ids::{ConcreteTypeId, FunctionId};
5use cairo_lang_sierra::program::{Program, Statement, StatementIdx};
6use cairo_lang_sierra::program_registry::ProgramRegistry;
7use cairo_lang_sierra_type_size::{TypeSizeMap, get_type_size_map};
8use cairo_lang_utils::casts::IntoOrPanic;
9use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
10use cairo_lang_utils::unordered_hash_map::{Entry, UnorderedHashMap};
11
12use crate::ap_change_info::ApChangeInfo;
13use crate::core_libfunc_ap_change::{self, InvocationApChangeInfoProvider};
14use crate::{ApChange, ApChangeError};
15
16/// Helper to implement the `InvocationApChangeInfoProvider` for the equation generation.
17struct InvocationApChangeInfoProviderForEqGen<'a, TokenUsages: Fn(CostTokenType) -> usize> {
18    /// Registry for providing the sizes of the types.
19    type_sizes: &'a TypeSizeMap,
20    /// Closure providing the token usages for the invocation.
21    token_usages: TokenUsages,
22}
23
24impl<TokenUsages: Fn(CostTokenType) -> usize> InvocationApChangeInfoProvider
25    for InvocationApChangeInfoProviderForEqGen<'_, TokenUsages>
26{
27    fn type_size(&self, ty: &ConcreteTypeId) -> usize {
28        self.type_sizes[ty].into_or_panic()
29    }
30
31    fn token_usages(&self, token_type: CostTokenType) -> usize {
32        (self.token_usages)(token_type)
33    }
34}
35
36/// A base to start ap tracking from.
37#[derive(Clone, Debug)]
38enum ApTrackingBase {
39    FunctionStart(FunctionId),
40    #[expect(dead_code)]
41    EnableStatement(StatementIdx),
42}
43
44/// The information for ap tracking of a statement.
45#[derive(Clone, Debug)]
46struct ApTrackingInfo {
47    /// The base tracking from.
48    base: ApTrackingBase,
49    /// The ap-change from the base.
50    ap_change: usize,
51}
52
53/// Helper for calculating the ap-changes of a program.
54struct ApChangeCalcHelper<'a, TokenUsages: Fn(StatementIdx, CostTokenType) -> usize> {
55    /// The program.
56    program: &'a Program,
57    /// The program registry.
58    registry: ProgramRegistry<CoreType, CoreLibfunc>,
59    /// Registry for providing the sizes of the types.
60    type_sizes: TypeSizeMap,
61    /// Closure providing the token usages for the invocation.
62    token_usages: TokenUsages,
63    /// The size of allocated locals until the statement.
64    locals_size: UnorderedHashMap<StatementIdx, usize>,
65    /// The lower bound of an ap-change to the furthest return per statement.
66    known_ap_change_to_return: UnorderedHashMap<StatementIdx, usize>,
67    /// The ap_change of functions with known ap changes.
68    function_ap_change: OrderedHashMap<FunctionId, usize>,
69    /// The ap tracking information per statement.
70    tracking_info: UnorderedHashMap<StatementIdx, ApTrackingInfo>,
71    /// The effective ap change from the statement's base.
72    effective_ap_change_from_base: UnorderedHashMap<StatementIdx, usize>,
73    /// The variables for ap alignment.
74    variable_values: OrderedHashMap<StatementIdx, usize>,
75}
76impl<'a, TokenUsages: Fn(StatementIdx, CostTokenType) -> usize>
77    ApChangeCalcHelper<'a, TokenUsages>
78{
79    /// Creates a new helper.
80    fn new(program: &'a Program, token_usages: TokenUsages) -> Result<Self, ApChangeError> {
81        let registry = ProgramRegistry::<CoreType, CoreLibfunc>::new(program)?;
82        let type_sizes = get_type_size_map(program, &registry).unwrap();
83        Ok(Self {
84            program,
85            registry,
86            type_sizes,
87            token_usages,
88            locals_size: Default::default(),
89            known_ap_change_to_return: Default::default(),
90            function_ap_change: Default::default(),
91            tracking_info: Default::default(),
92            effective_ap_change_from_base: Default::default(),
93            variable_values: Default::default(),
94        })
95    }
96
97    /// Calculates the locals size and function ap changes.
98    fn calc_locals_and_function_ap_changes(&mut self) -> Result<(), ApChangeError> {
99        let rev_ordering = self.known_ap_change_reverse_topological_order()?;
100        for idx in rev_ordering.iter().rev() {
101            self.calc_locals_for_statement(*idx)?;
102        }
103        for idx in rev_ordering {
104            self.calc_known_ap_change_for_statement(idx)?;
105        }
106        self.function_ap_change = self
107            .program
108            .funcs
109            .iter()
110            .filter_map(|f| {
111                self.known_ap_change_to_return
112                    .get(&f.entry_point)
113                    .cloned()
114                    .map(|ap_change| (f.id.clone(), ap_change))
115            })
116            .collect();
117        Ok(())
118    }
119
120    /// Calculates the locals size for a statement.
121    fn calc_locals_for_statement(&mut self, idx: StatementIdx) -> Result<(), ApChangeError> {
122        for (ap_change, target) in self.get_branches(idx)? {
123            match ap_change {
124                ApChange::AtLocalsFinalization(x) => {
125                    self.locals_size.insert(target, self.get_statement_locals(idx) + x);
126                }
127                ApChange::Unknown | ApChange::FinalizeLocals => {}
128                ApChange::FromMetadata
129                | ApChange::FunctionCall(_)
130                | ApChange::EnableApTracking
131                | ApChange::Known(_)
132                | ApChange::DisableApTracking => {
133                    if let Some(locals) = self.locals_size.get(&idx) {
134                        self.locals_size.insert(target, *locals);
135                    }
136                }
137            }
138        }
139        Ok(())
140    }
141
142    /// Calculates the lower bound of an ap-change to the furthest return per statement.
143    /// If it is unknown does not set it.
144    fn calc_known_ap_change_for_statement(
145        &mut self,
146        idx: StatementIdx,
147    ) -> Result<(), ApChangeError> {
148        let mut max_change = 0;
149        for (ap_change, target) in self.get_branches(idx)? {
150            let Some(target_ap_change) = self.known_ap_change_to_return.get(&target) else {
151                return Ok(());
152            };
153            if let Some(ap_change) = self.branch_ap_change(idx, &ap_change, |id| {
154                self.known_ap_change_to_return.get(&self.func_entry_point(id).ok()?).cloned()
155            }) {
156                max_change = max_change.max(target_ap_change + ap_change);
157            } else {
158                return Ok(());
159            };
160        }
161        self.known_ap_change_to_return.insert(idx, max_change);
162        Ok(())
163    }
164
165    /// Returns the topological ordering of the program statements for fully known ap-changes.
166    fn known_ap_change_reverse_topological_order(
167        &self,
168    ) -> Result<Vec<StatementIdx>, ApChangeError> {
169        reverse_topological_ordering(
170            false,
171            (0..self.program.statements.len()).map(StatementIdx),
172            self.program.statements.len(),
173            |idx| {
174                let mut res = vec![];
175                for (ap_change, target) in self.get_branches(idx)? {
176                    res.push(target);
177                    if let ApChange::FunctionCall(id) = ap_change {
178                        res.push(self.func_entry_point(&id)?);
179                    }
180                }
181                Ok(res)
182            },
183            |_| unreachable!("Cycle isn't an error."),
184        )
185    }
186
187    /// Returns the topological ordering of the program statements where tracked ap changes give the
188    /// ordering.
189    fn tracked_ap_change_reverse_topological_order(
190        &self,
191    ) -> Result<Vec<StatementIdx>, ApChangeError> {
192        reverse_topological_ordering(
193            false,
194            (0..self.program.statements.len()).map(StatementIdx),
195            self.program.statements.len(),
196            |idx| {
197                Ok(self
198                    .get_branches(idx)?
199                    .into_iter()
200                    .flat_map(|(ap_change, target)| match ap_change {
201                        ApChange::Unknown => None,
202                        ApChange::FunctionCall(id) => {
203                            if self.function_ap_change.contains_key(&id) {
204                                Some(target)
205                            } else {
206                                None
207                            }
208                        }
209                        ApChange::Known(_)
210                        | ApChange::DisableApTracking
211                        | ApChange::FromMetadata
212                        | ApChange::AtLocalsFinalization(_)
213                        | ApChange::FinalizeLocals
214                        | ApChange::EnableApTracking => Some(target),
215                    })
216                    .collect())
217            },
218            |_| unreachable!("Cycle isn't an error."),
219        )
220    }
221
222    /// Calculates the tracking information for a statement.
223    fn calc_tracking_info_for_statement(&mut self, idx: StatementIdx) -> Result<(), ApChangeError> {
224        for (ap_change, target) in self.get_branches(idx)? {
225            if matches!(ap_change, ApChange::EnableApTracking) {
226                self.tracking_info.insert(
227                    target,
228                    ApTrackingInfo { base: ApTrackingBase::EnableStatement(idx), ap_change: 0 },
229                );
230                continue;
231            }
232            let Some(mut base_info) = self.tracking_info.get(&idx).cloned() else {
233                continue;
234            };
235            if let Some(ap_change) = self
236                .branch_ap_change(idx, &ap_change, |id| self.function_ap_change.get(id).cloned())
237            {
238                base_info.ap_change += ap_change;
239            } else {
240                continue;
241            }
242            match self.tracking_info.entry(target) {
243                Entry::Occupied(e) => {
244                    e.into_mut().ap_change = e.get().ap_change.max(base_info.ap_change);
245                }
246                Entry::Vacant(e) => {
247                    e.insert(base_info);
248                }
249            }
250        }
251        Ok(())
252    }
253
254    /// Calculates the effective ap change for a statement, and the variables for ap alignment.
255    fn calc_effective_ap_change_and_variables_per_statement(
256        &mut self,
257        idx: StatementIdx,
258    ) -> Result<(), ApChangeError> {
259        let Some(base_info) = self.tracking_info.get(&idx).cloned() else {
260            return Ok(());
261        };
262        if matches!(self.program.get_statement(&idx), Some(Statement::Return(_))) {
263            if let ApTrackingBase::FunctionStart(id) = base_info.base {
264                if let Some(func_change) = self.function_ap_change.get(&id) {
265                    self.effective_ap_change_from_base.insert(idx, *func_change);
266                }
267            }
268            return Ok(());
269        }
270        let mut source_ap_change = None;
271        let mut paths_ap_change = vec![];
272        for (ap_change, target) in self.get_branches(idx)? {
273            if matches!(ap_change, ApChange::EnableApTracking) {
274                continue;
275            }
276            let Some(change) = self
277                .branch_ap_change(idx, &ap_change, |id| self.function_ap_change.get(id).cloned())
278            else {
279                source_ap_change = Some(base_info.ap_change);
280                continue;
281            };
282            let Some(target_ap_change) = self.effective_ap_change_from_base.get(&target) else {
283                continue;
284            };
285            let calc_ap_change = target_ap_change - change;
286            paths_ap_change.push((target, calc_ap_change));
287            if let Some(source_ap_change) = &mut source_ap_change {
288                *source_ap_change = (*source_ap_change).min(calc_ap_change);
289            } else {
290                source_ap_change = Some(calc_ap_change);
291            }
292        }
293        if let Some(source_ap_change) = source_ap_change {
294            self.effective_ap_change_from_base.insert(idx, source_ap_change);
295            for (target, path_ap_change) in paths_ap_change {
296                if path_ap_change != source_ap_change {
297                    self.variable_values.insert(target, path_ap_change - source_ap_change);
298                }
299            }
300        }
301        Ok(())
302    }
303
304    /// Gets the actual ap-change of a branch.
305    fn branch_ap_change(
306        &self,
307        idx: StatementIdx,
308        ap_change: &ApChange,
309        func_ap_change: impl Fn(&FunctionId) -> Option<usize>,
310    ) -> Option<usize> {
311        match ap_change {
312            ApChange::Unknown | ApChange::DisableApTracking => None,
313            ApChange::Known(x) => Some(*x),
314            ApChange::FromMetadata
315            | ApChange::AtLocalsFinalization(_)
316            | ApChange::EnableApTracking => Some(0),
317            ApChange::FinalizeLocals => Some(self.get_statement_locals(idx)),
318            ApChange::FunctionCall(id) => func_ap_change(id).map(|x| 2 + x),
319        }
320    }
321
322    /// Returns the locals size for a statement.
323    fn get_statement_locals(&self, idx: StatementIdx) -> usize {
324        self.locals_size.get(&idx).cloned().unwrap_or_default()
325    }
326
327    /// Returns the branches of a statement.
328    fn get_branches(
329        &self,
330        idx: StatementIdx,
331    ) -> Result<Vec<(ApChange, StatementIdx)>, ApChangeError> {
332        Ok(match self.program.get_statement(&idx).unwrap() {
333            Statement::Invocation(invocation) => {
334                let libfunc = self.registry.get_libfunc(&invocation.libfunc_id)?;
335                core_libfunc_ap_change::core_libfunc_ap_change(
336                    libfunc,
337                    &InvocationApChangeInfoProviderForEqGen {
338                        type_sizes: &self.type_sizes,
339                        token_usages: |token_type| (self.token_usages)(idx, token_type),
340                    },
341                )
342                .into_iter()
343                .zip(&invocation.branches)
344                .map(|(ap_change, branch_info)| (ap_change, idx.next(&branch_info.target)))
345                .collect()
346            }
347            Statement::Return(_) => vec![],
348        })
349    }
350
351    /// Returns the entry point of a function.
352    fn func_entry_point(&self, id: &FunctionId) -> Result<StatementIdx, ApChangeError> {
353        Ok(self.registry.get_function(id)?.entry_point)
354    }
355}
356
357/// Calculates ap change information for a given program.
358pub fn calc_ap_changes<TokenUsages: Fn(StatementIdx, CostTokenType) -> usize>(
359    program: &Program,
360    token_usages: TokenUsages,
361) -> Result<ApChangeInfo, ApChangeError> {
362    let mut helper = ApChangeCalcHelper::new(program, token_usages)?;
363    helper.calc_locals_and_function_ap_changes()?;
364    let ap_tracked_reverse_topological_ordering =
365        helper.tracked_ap_change_reverse_topological_order()?;
366    // Setting tracking info for function entry points.
367    for f in &program.funcs {
368        helper.tracking_info.insert(
369            f.entry_point,
370            ApTrackingInfo { base: ApTrackingBase::FunctionStart(f.id.clone()), ap_change: 0 },
371        );
372    }
373    for idx in ap_tracked_reverse_topological_ordering.iter().rev() {
374        helper.calc_tracking_info_for_statement(*idx)?;
375    }
376    for idx in ap_tracked_reverse_topological_ordering {
377        helper.calc_effective_ap_change_and_variables_per_statement(idx)?;
378    }
379    Ok(ApChangeInfo {
380        variable_values: helper.variable_values,
381        function_ap_change: helper.function_ap_change,
382    })
383}