cairo_lang_sierra_to_casm/
annotations.rs

1use std::iter;
2
3use cairo_lang_casm::ap_change::{ApChangeError, ApplyApChange};
4use cairo_lang_sierra::edit_state::{put_results, take_args};
5use cairo_lang_sierra::ids::{ConcreteTypeId, FunctionId, VarId};
6use cairo_lang_sierra::program::{BranchInfo, Function, StatementIdx};
7use cairo_lang_sierra_type_size::TypeSizeMap;
8use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
9use itertools::zip_eq;
10use thiserror::Error;
11
12use crate::environment::ap_tracking::update_ap_tracking;
13use crate::environment::frame_state::FrameStateError;
14use crate::environment::gas_wallet::{GasWallet, GasWalletError};
15use crate::environment::{
16    ApTracking, ApTrackingBase, Environment, EnvironmentError, validate_environment_equality,
17    validate_final_environment,
18};
19use crate::invocations::{ApTrackingChange, BranchChanges};
20use crate::metadata::Metadata;
21use crate::references::{
22    IntroductionPoint, OutputReferenceValueIntroductionPoint, ReferenceExpression, ReferenceValue,
23    ReferencesError, StatementRefs, build_function_parameters_refs, check_types_match,
24};
25
/// Errors reported while creating, propagating or validating statement annotations.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum AnnotationError {
    /// A statement that was already annotated was reached again with references that do not
    /// agree with the stored ones.
    #[error("#{statement_idx}: Inconsistent references annotations: {error}")]
    InconsistentReferencesAnnotation {
        statement_idx: StatementIdx,
        error: InconsistentReferenceError,
    },
    /// The destination statement was required to be unannotated, but already had annotations.
    #[error("#{source_statement_idx}->#{destination_statement_idx}: Annotation was already set.")]
    AnnotationAlreadySet {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
    },
    /// An environment check failed at the given statement (either when comparing two environments
    /// or when validating the final environment).
    #[error("#{statement_idx}: {error}")]
    InconsistentEnvironments { statement_idx: StatementIdx, error: EnvironmentError },
    /// The statement was annotated with two different function ids.
    #[error("#{statement_idx}: Belongs to two different functions.")]
    InconsistentFunctionId { statement_idx: StatementIdx },
    /// The statement was reached again although its stored annotations do not allow convergence.
    #[error("#{statement_idx}: Invalid convergence.")]
    InvalidConvergence { statement_idx: StatementIdx },
    /// A statement index is outside of the annotated statement range.
    #[error("InvalidStatementIdx")]
    InvalidStatementIdx,
    /// The statement has no annotations.
    /// Note that the rendered message does not include the statement index.
    #[error("MissingAnnotationsForStatement")]
    MissingAnnotationsForStatement(StatementIdx),
    /// A variable that the statement takes as an argument is not available in its references.
    #[error("#{statement_idx}: {var_id} is undefined.")]
    MissingReferenceError { statement_idx: StatementIdx, var_id: VarId },
    /// A branch result would override a variable that is still present in the references.
    #[error("#{source_statement_idx}->#{destination_statement_idx}: {var_id} was overridden.")]
    OverrideReferenceError {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        var_id: VarId,
    },
    /// A frame state error, forwarded as is (also available through `From<FrameStateError>`).
    #[error(transparent)]
    FrameStateError(#[from] FrameStateError),
    /// Updating the gas wallet with the gas change of a branch failed.
    #[error("#{source_statement_idx}->#{destination_statement_idx}: {error}")]
    GasWalletError {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        error: GasWalletError,
    },
    /// A references related failure (building the function parameter references or matching
    /// the returned types) at the given statement.
    #[error("#{statement_idx}: {error}")]
    ReferencesError { statement_idx: StatementIdx, error: ReferencesError },
    /// A branch tried to enable ap tracking while it was not disabled.
    #[error("#{statement_idx}: Attempting to enable ap tracking when already enabled.")]
    ApTrackingAlreadyEnabled { statement_idx: StatementIdx },
    /// Applying the ap change of a branch to the expression of `var_id` failed.
    #[error(
        "#{source_statement_idx}->#{destination_statement_idx}: Got '{error}' error while moving \
         {var_id}."
    )]
    ApChangeError {
        var_id: VarId,
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        error: ApChangeError,
    },
    /// Updating the ap tracking with the ap change of a branch failed.
    /// Note that the underlying `error` is stored but is not part of the rendered message.
    #[error("#{source_statement_idx} -> #{destination_statement_idx}: Ap tracking error")]
    ApTrackingError {
        source_statement_idx: StatementIdx,
        destination_statement_idx: StatementIdx,
        error: ApChangeError,
    },
    /// The ap tracking at a return statement differs from the one expected for the function.
    #[error(
        "#{statement_idx}: Invalid function ap change annotation. Expected ap tracking: \
         {expected:?}, got: {actual:?}."
    )]
    InvalidFunctionApChange {
        statement_idx: StatementIdx,
        expected: ApTracking,
        actual: ApTracking,
    },
}
94
/// Error representing an inconsistency in the references annotations.
///
/// Reported when two sets of references that reach the same statement are compared and do not
/// agree. `expected` always refers to the previously stored reference and `actual` to the new one.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum InconsistentReferenceError {
    /// The variable has a different type in each set.
    #[error("Variable {var} type mismatch. Expected `{expected}`, got `{actual}`.")]
    TypeMismatch { var: VarId, expected: ConcreteTypeId, actual: ConcreteTypeId },
    /// The variable has a different reference expression in each set.
    #[error("Variable {var} expression mismatch. Expected `{expected}`, got `{actual}`.")]
    ExpressionMismatch { var: VarId, expected: ReferenceExpression, actual: ReferenceExpression },
    /// The variable has a different stack index in each set.
    #[error("Variable {var} stack index mismatch. Expected `{expected:?}`, got `{actual:?}`.")]
    StackIndexMismatch { var: VarId, expected: Option<usize>, actual: Option<usize> },
    /// The variable was introduced at a different point in each set.
    #[error("Variable {var} introduction point mismatch. Expected `{expected}`, got `{actual}`.")]
    IntroductionPointMismatch { var: VarId, expected: IntroductionPoint, actual: IntroductionPoint },
    /// The two sets hold a different number of variables.
    #[error("Variable count mismatch.")]
    VariableCountMismatch,
    /// A variable of the new set does not exist in the stored set.
    #[error("Missing expected variable {0}.")]
    VariableMissing(VarId),
    /// The variable can only be merged when ap tracking is enabled, but it is not.
    #[error("Ap tracking is disabled while trying to merge {0}.")]
    ApTrackingDisabled(VarId),
}
113
/// Annotations that represent the state at a single program statement.
#[derive(Clone, Debug)]
pub struct StatementAnnotations {
    /// The references available at the statement, looked up by variable id.
    pub refs: StatementRefs,
    /// The function id that the statement belongs to.
    pub function_id: FunctionId,
    /// Indicates whether convergence is allowed in the given statement, i.e. whether another flow
    /// may reach the statement after these annotations were set.
    pub convergence_allowed: bool,
    /// The environment (ap tracking, stack size, frame state and gas wallet) at the statement.
    pub environment: Environment,
}
124
/// Annotations of the program statements.
/// See StatementAnnotations.
pub struct ProgramAnnotations {
    /// Optional per statement annotation, indexed by the statement index.
    /// An entry is `None` as long as the statement was not annotated.
    per_statement_annotations: Vec<Option<StatementAnnotations>>,
    /// The indices of the statements that are the targets of backwards jumps.
    /// The annotations of these statements are kept intact when they are read, so that later
    /// flows can still be compared against them.
    backwards_jump_indices: UnorderedHashSet<StatementIdx>,
}
133impl ProgramAnnotations {
134    fn new(n_statements: usize, backwards_jump_indices: UnorderedHashSet<StatementIdx>) -> Self {
135        ProgramAnnotations {
136            per_statement_annotations: iter::repeat_with(|| None).take(n_statements).collect(),
137            backwards_jump_indices,
138        }
139    }
140
141    /// Creates a ProgramAnnotations object based on 'n_statements', a given functions list
142    /// and metadata for the program.
143    pub fn create(
144        n_statements: usize,
145        backwards_jump_indices: UnorderedHashSet<StatementIdx>,
146        functions: &[Function],
147        metadata: &Metadata,
148        gas_usage_check: bool,
149        type_sizes: &TypeSizeMap,
150    ) -> Result<Self, AnnotationError> {
151        let mut annotations = ProgramAnnotations::new(n_statements, backwards_jump_indices);
152        for func in functions {
153            annotations.set_or_assert(func.entry_point, StatementAnnotations {
154                refs: build_function_parameters_refs(func, type_sizes).map_err(|error| {
155                    AnnotationError::ReferencesError { statement_idx: func.entry_point, error }
156                })?,
157                function_id: func.id.clone(),
158                convergence_allowed: false,
159                environment: Environment::new(if gas_usage_check {
160                    GasWallet::Value(metadata.gas_info.function_costs[&func.id].clone())
161                } else {
162                    GasWallet::Disabled
163                }),
164            })?
165        }
166
167        Ok(annotations)
168    }
169
170    /// Sets the annotations at 'statement_idx' to 'annotations'
171    /// If the annotations for this statement were set previously asserts that the previous
172    /// assignment is consistent with the new assignment and verifies that convergence_allowed
173    /// is true.
174    pub fn set_or_assert(
175        &mut self,
176        statement_idx: StatementIdx,
177        annotations: StatementAnnotations,
178    ) -> Result<(), AnnotationError> {
179        let idx = statement_idx.0;
180        match self.per_statement_annotations.get(idx).ok_or(AnnotationError::InvalidStatementIdx)? {
181            None => self.per_statement_annotations[idx] = Some(annotations),
182            Some(expected_annotations) => {
183                if expected_annotations.function_id != annotations.function_id {
184                    return Err(AnnotationError::InconsistentFunctionId { statement_idx });
185                }
186                validate_environment_equality(
187                    &expected_annotations.environment,
188                    &annotations.environment,
189                )
190                .map_err(|error| AnnotationError::InconsistentEnvironments {
191                    statement_idx,
192                    error,
193                })?;
194                self.test_references_consistency(&annotations, expected_annotations).map_err(
195                    |error| AnnotationError::InconsistentReferencesAnnotation {
196                        statement_idx,
197                        error,
198                    },
199                )?;
200
201                // Note that we ignore annotations here.
202                // a flow cannot converge with a branch target.
203                if !expected_annotations.convergence_allowed {
204                    return Err(AnnotationError::InvalidConvergence { statement_idx });
205                }
206            }
207        };
208        Ok(())
209    }
210
211    /// Checks whether or not `actual` and `expected` references are consistent.
212    /// Returns an error representing the inconsistency.
213    fn test_references_consistency(
214        &self,
215        actual: &StatementAnnotations,
216        expected: &StatementAnnotations,
217    ) -> Result<(), InconsistentReferenceError> {
218        // Check if there is a mismatch at the number of variables.
219        if actual.refs.len() != expected.refs.len() {
220            return Err(InconsistentReferenceError::VariableCountMismatch);
221        }
222        let ap_tracking_enabled =
223            matches!(actual.environment.ap_tracking, ApTracking::Enabled { .. });
224        for (var_id, actual_ref) in actual.refs.iter() {
225            // Check if the variable exists in just one of the branches.
226            let Some(expected_ref) = expected.refs.get(var_id) else {
227                return Err(InconsistentReferenceError::VariableMissing(var_id.clone()));
228            };
229            // Check if the variable doesn't match on type, expression or stack information.
230            if actual_ref.ty != expected_ref.ty {
231                return Err(InconsistentReferenceError::TypeMismatch {
232                    var: var_id.clone(),
233                    expected: expected_ref.ty.clone(),
234                    actual: actual_ref.ty.clone(),
235                });
236            }
237            if actual_ref.expression != expected_ref.expression {
238                return Err(InconsistentReferenceError::ExpressionMismatch {
239                    var: var_id.clone(),
240                    expected: expected_ref.expression.clone(),
241                    actual: actual_ref.expression.clone(),
242                });
243            }
244            if actual_ref.stack_idx != expected_ref.stack_idx {
245                return Err(InconsistentReferenceError::StackIndexMismatch {
246                    var: var_id.clone(),
247                    expected: expected_ref.stack_idx,
248                    actual: actual_ref.stack_idx,
249                });
250            }
251            test_var_consistency(var_id, actual_ref, expected_ref, ap_tracking_enabled)?;
252        }
253        Ok(())
254    }
255
256    /// Returns the result of applying take_args to the StatementAnnotations at statement_idx.
257    /// Can be called only once per item, the item is removed from the annotations, and can no
258    /// longer be used for merges.
259    pub fn get_annotations_after_take_args<'a>(
260        &mut self,
261        statement_idx: StatementIdx,
262        ref_ids: impl Iterator<Item = &'a VarId>,
263    ) -> Result<(StatementAnnotations, Vec<ReferenceValue>), AnnotationError> {
264        let existing = self.per_statement_annotations[statement_idx.0]
265            .as_mut()
266            .ok_or(AnnotationError::MissingAnnotationsForStatement(statement_idx))?;
267        let mut updated = if self.backwards_jump_indices.contains(&statement_idx) {
268            existing.clone()
269        } else {
270            std::mem::replace(existing, StatementAnnotations {
271                refs: Default::default(),
272                function_id: existing.function_id.clone(),
273                // Merging with this data is no longer allowed.
274                convergence_allowed: false,
275                environment: existing.environment.clone(),
276            })
277        };
278        let refs = std::mem::take(&mut updated.refs);
279        let (statement_refs, taken_refs) = take_args(refs, ref_ids).map_err(|error| {
280            AnnotationError::MissingReferenceError { statement_idx, var_id: error.var_id() }
281        })?;
282        updated.refs = statement_refs;
283        Ok((updated, taken_refs))
284    }
285
286    /// Propagates the annotations from `statement_idx` to 'destination_statement_idx'.
287    ///
288    /// `annotations` is the result of calling get_annotations_after_take_args at
289    /// `source_statement_idx` and `branch_changes` are the reference changes at each branch.
290    ///  if `must_set` is true, asserts that destination_statement_idx wasn't annotated before.
291    pub fn propagate_annotations(
292        &mut self,
293        source_statement_idx: StatementIdx,
294        destination_statement_idx: StatementIdx,
295        mut annotations: StatementAnnotations,
296        branch_info: &BranchInfo,
297        branch_changes: BranchChanges,
298        must_set: bool,
299    ) -> Result<(), AnnotationError> {
300        if must_set && self.per_statement_annotations[destination_statement_idx.0].is_some() {
301            return Err(AnnotationError::AnnotationAlreadySet {
302                source_statement_idx,
303                destination_statement_idx,
304            });
305        }
306
307        for (var_id, ref_value) in annotations.refs.iter_mut() {
308            if branch_changes.clear_old_stack {
309                ref_value.stack_idx = None;
310            }
311            ref_value.expression =
312                std::mem::replace(&mut ref_value.expression, ReferenceExpression::zero_sized())
313                    .apply_ap_change(branch_changes.ap_change)
314                    .map_err(|error| AnnotationError::ApChangeError {
315                        var_id: var_id.clone(),
316                        source_statement_idx,
317                        destination_statement_idx,
318                        error,
319                    })?;
320        }
321        let mut refs = put_results(
322            annotations.refs,
323            zip_eq(
324                &branch_info.results,
325                branch_changes.refs.into_iter().map(|value| ReferenceValue {
326                    expression: value.expression,
327                    ty: value.ty,
328                    stack_idx: value.stack_idx,
329                    introduction_point: match value.introduction_point {
330                        OutputReferenceValueIntroductionPoint::New(output_idx) => {
331                            IntroductionPoint {
332                                source_statement_idx: Some(source_statement_idx),
333                                destination_statement_idx,
334                                output_idx,
335                            }
336                        }
337                        OutputReferenceValueIntroductionPoint::Existing(introduction_point) => {
338                            introduction_point
339                        }
340                    },
341                }),
342            ),
343        )
344        .map_err(|error| AnnotationError::OverrideReferenceError {
345            source_statement_idx,
346            destination_statement_idx,
347            var_id: error.var_id(),
348        })?;
349
350        // Since some variables on the stack may have been consumed by the libfunc, we need to
351        // find the new stack size. This is done by searching from the bottom of the stack until we
352        // find a missing variable.
353        let available_stack_indices: UnorderedHashSet<_> =
354            refs.values().flat_map(|r| r.stack_idx).collect();
355        let new_stack_size_opt = (0..branch_changes.new_stack_size)
356            .find(|i| !available_stack_indices.contains(&(branch_changes.new_stack_size - 1 - i)));
357        let stack_size = if let Some(new_stack_size) = new_stack_size_opt {
358            // The number of stack elements which were removed.
359            let stack_removal = branch_changes.new_stack_size - new_stack_size;
360            for (_, r) in refs.iter_mut() {
361                // Subtract the number of stack elements removed. If the result is negative,
362                // `stack_idx` is set to `None` and the variable is removed from the stack.
363                r.stack_idx =
364                    r.stack_idx.and_then(|stack_idx| stack_idx.checked_sub(stack_removal));
365            }
366            new_stack_size
367        } else {
368            branch_changes.new_stack_size
369        };
370
371        let ap_tracking = match branch_changes.ap_tracking_change {
372            ApTrackingChange::Disable => ApTracking::Disabled,
373            ApTrackingChange::Enable => {
374                if !matches!(annotations.environment.ap_tracking, ApTracking::Disabled) {
375                    return Err(AnnotationError::ApTrackingAlreadyEnabled {
376                        statement_idx: source_statement_idx,
377                    });
378                }
379                ApTracking::Enabled {
380                    ap_change: 0,
381                    base: ApTrackingBase::Statement(destination_statement_idx),
382                }
383            }
384            ApTrackingChange::None => {
385                update_ap_tracking(annotations.environment.ap_tracking, branch_changes.ap_change)
386                    .map_err(|error| AnnotationError::ApTrackingError {
387                        source_statement_idx,
388                        destination_statement_idx,
389                        error,
390                    })?
391            }
392        };
393
394        self.set_or_assert(destination_statement_idx, StatementAnnotations {
395            refs,
396            function_id: annotations.function_id,
397            convergence_allowed: !must_set,
398            environment: Environment {
399                ap_tracking,
400                stack_size,
401                frame_state: annotations.environment.frame_state,
402                gas_wallet: annotations
403                    .environment
404                    .gas_wallet
405                    .update(branch_changes.gas_change)
406                    .map_err(|error| AnnotationError::GasWalletError {
407                        source_statement_idx,
408                        destination_statement_idx,
409                        error,
410                    })?,
411            },
412        })
413    }
414
415    /// Validates the ap change and return types in a return statement.
416    pub fn validate_return_properties(
417        &self,
418        statement_idx: StatementIdx,
419        annotations: &StatementAnnotations,
420        functions: &[Function],
421        metadata: &Metadata,
422        return_refs: &[ReferenceValue],
423    ) -> Result<(), AnnotationError> {
424        // TODO(ilya): Don't use linear search.
425        let func = &functions.iter().find(|func| func.id == annotations.function_id).unwrap();
426
427        let expected_ap_tracking = match metadata.ap_change_info.function_ap_change.get(&func.id) {
428            Some(x) => ApTracking::Enabled { ap_change: *x, base: ApTrackingBase::FunctionStart },
429            None => ApTracking::Disabled,
430        };
431        if annotations.environment.ap_tracking != expected_ap_tracking {
432            return Err(AnnotationError::InvalidFunctionApChange {
433                statement_idx,
434                expected: expected_ap_tracking,
435                actual: annotations.environment.ap_tracking,
436            });
437        }
438
439        // Checks that the list of return reference contains has the expected types.
440        check_types_match(return_refs, &func.signature.ret_types)
441            .map_err(|error| AnnotationError::ReferencesError { statement_idx, error })?;
442        Ok(())
443    }
444
445    /// Validates the final annotation in a return statement.
446    pub fn validate_final_annotations(
447        &self,
448        statement_idx: StatementIdx,
449        annotations: &StatementAnnotations,
450        functions: &[Function],
451        metadata: &Metadata,
452        return_refs: &[ReferenceValue],
453    ) -> Result<(), AnnotationError> {
454        self.validate_return_properties(
455            statement_idx,
456            annotations,
457            functions,
458            metadata,
459            return_refs,
460        )?;
461        validate_final_environment(&annotations.environment)
462            .map_err(|error| AnnotationError::InconsistentEnvironments { statement_idx, error })
463    }
464}
465
466/// Checks whether or not the references `actual` and `expected` are consistent and can be merged
467/// in a way that will be re-compilable.
468/// Returns an error representing the inconsistency.
469fn test_var_consistency(
470    var_id: &VarId,
471    actual: &ReferenceValue,
472    expected: &ReferenceValue,
473    ap_tracking_enabled: bool,
474) -> Result<(), InconsistentReferenceError> {
475    // If the variable is on the stack, it can always be merged.
476    if actual.stack_idx.is_some() {
477        return Ok(());
478    }
479    // If the variable is not ap-dependent it can always be merged.
480    // Note: This makes the assumption that empty variables are always mergeable.
481    if actual.expression.can_apply_unknown() {
482        return Ok(());
483    }
484    // Ap tracking must be enabled when merging non-stack ap-dependent variables.
485    if !ap_tracking_enabled {
486        return Err(InconsistentReferenceError::ApTrackingDisabled(var_id.clone()));
487    }
488    // Merged variables must have the same introduction point.
489    if actual.introduction_point == expected.introduction_point {
490        Ok(())
491    } else {
492        Err(InconsistentReferenceError::IntroductionPointMismatch {
493            var: var_id.clone(),
494            expected: expected.introduction_point.clone(),
495            actual: actual.introduction_point.clone(),
496        })
497    }
498}