cairo_lang_sierra_to_casm/
annotations.rs

1use std::iter;
2
3use cairo_lang_casm::ap_change::{ApChangeError, ApplyApChange};
4use cairo_lang_sierra::edit_state::{put_results, take_args};
5use cairo_lang_sierra::ids::{ConcreteTypeId, FunctionId, VarId};
6use cairo_lang_sierra::program::{BranchInfo, Function, StatementIdx};
7use cairo_lang_sierra_type_size::TypeSizeMap;
8use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
9use itertools::{chain, zip_eq};
10use thiserror::Error;
11
12use crate::environment::ap_tracking::update_ap_tracking;
13use crate::environment::frame_state::FrameStateError;
14use crate::environment::gas_wallet::{GasWallet, GasWalletError};
15use crate::environment::{
16    ApTracking, ApTrackingBase, Environment, EnvironmentError, validate_environment_equality,
17    validate_final_environment,
18};
19use crate::invocations::{ApTrackingChange, BranchChanges};
20use crate::metadata::Metadata;
21use crate::references::{
22    IntroductionPoint, OutputReferenceValueIntroductionPoint, ReferenceExpression, ReferenceValue,
23    ReferencesError, StatementRefs, build_function_parameters_refs, check_types_match,
24};
25
26#[derive(Error, Debug, Eq, PartialEq)]
27pub enum AnnotationError {
28    #[error("#{statement_idx}: Inconsistent references annotations: {error}")]
29    InconsistentReferencesAnnotation {
30        statement_idx: StatementIdx,
31        error: InconsistentReferenceError,
32    },
33    #[error("#{source_statement_idx}->#{destination_statement_idx}: Annotation was already set.")]
34    AnnotationAlreadySet {
35        source_statement_idx: StatementIdx,
36        destination_statement_idx: StatementIdx,
37    },
38    #[error("#{statement_idx}: {error}")]
39    InconsistentEnvironments { statement_idx: StatementIdx, error: EnvironmentError },
40    #[error("#{statement_idx}: Belongs to two different functions.")]
41    InconsistentFunctionId { statement_idx: StatementIdx },
42    #[error("#{statement_idx}: Invalid convergence.")]
43    InvalidConvergence { statement_idx: StatementIdx },
44    #[error("InvalidStatementIdx")]
45    InvalidStatementIdx,
46    #[error("MissingAnnotationsForStatement")]
47    MissingAnnotationsForStatement(StatementIdx),
48    #[error("#{statement_idx}: {var_id} is undefined.")]
49    MissingReferenceError { statement_idx: StatementIdx, var_id: VarId },
50    #[error("#{source_statement_idx}->#{destination_statement_idx}: {var_id} was overridden.")]
51    OverrideReferenceError {
52        source_statement_idx: StatementIdx,
53        destination_statement_idx: StatementIdx,
54        var_id: VarId,
55    },
56    #[error(transparent)]
57    FrameStateError(#[from] FrameStateError),
58    #[error("#{source_statement_idx}->#{destination_statement_idx}: {error}")]
59    GasWalletError {
60        source_statement_idx: StatementIdx,
61        destination_statement_idx: StatementIdx,
62        error: GasWalletError,
63    },
64    #[error("#{statement_idx}: {error}")]
65    ReferencesError { statement_idx: StatementIdx, error: ReferencesError },
66    #[error("#{statement_idx}: Attempting to enable ap tracking when already enabled.")]
67    ApTrackingAlreadyEnabled { statement_idx: StatementIdx },
68    #[error(
69        "#{source_statement_idx}->#{destination_statement_idx}: Got '{error}' error while moving \
70         {var_id} introduced at {} .", {introduction_point}
71    )]
72    ApChangeError {
73        var_id: VarId,
74        source_statement_idx: StatementIdx,
75        destination_statement_idx: StatementIdx,
76        introduction_point: IntroductionPoint,
77        error: ApChangeError,
78    },
79    #[error("#{source_statement_idx} -> #{destination_statement_idx}: Ap tracking error")]
80    ApTrackingError {
81        source_statement_idx: StatementIdx,
82        destination_statement_idx: StatementIdx,
83        error: ApChangeError,
84    },
85    #[error(
86        "#{statement_idx}: Invalid function ap change annotation. Expected ap tracking: \
87         {expected:?}, got: {actual:?}."
88    )]
89    InvalidFunctionApChange {
90        statement_idx: StatementIdx,
91        expected: ApTracking,
92        actual: ApTracking,
93    },
94}
95
96impl AnnotationError {
97    pub fn stmt_indices(&self) -> Vec<StatementIdx> {
98        match self {
99            AnnotationError::ApChangeError {
100                source_statement_idx,
101                destination_statement_idx,
102                introduction_point,
103                ..
104            } => chain!(
105                [source_statement_idx, destination_statement_idx],
106                &introduction_point.source_statement_idx,
107                [&introduction_point.destination_statement_idx]
108            )
109            .cloned()
110            .collect(),
111            _ => vec![],
112        }
113    }
114}
115
/// Error representing an inconsistency in the references annotations.
///
/// Produced when the references of a flow reaching a statement cannot be merged with the
/// references already recorded at that statement.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum InconsistentReferenceError {
    /// The variable has a different type in the two annotations.
    #[error("Variable {var} type mismatch. Expected `{expected}`, got `{actual}`.")]
    TypeMismatch { var: VarId, expected: ConcreteTypeId, actual: ConcreteTypeId },
    /// The variable is described by a different reference expression in the two annotations.
    #[error("Variable {var} expression mismatch. Expected `{expected}`, got `{actual}`.")]
    ExpressionMismatch { var: VarId, expected: ReferenceExpression, actual: ReferenceExpression },
    /// The variable has a different stack position (or is on the stack in only one of them).
    #[error("Variable {var} stack index mismatch. Expected `{expected:?}`, got `{actual:?}`.")]
    StackIndexMismatch { var: VarId, expected: Option<usize>, actual: Option<usize> },
    /// A non-stack, ap-dependent variable was introduced at different points in the two flows.
    #[error("Variable {var} introduction point mismatch. Expected `{expected}`, got `{actual}`.")]
    IntroductionPointMismatch { var: VarId, expected: IntroductionPoint, actual: IntroductionPoint },
    /// The two annotations define a different number of variables.
    #[error("Variable count mismatch.")]
    VariableCountMismatch,
    /// A variable of the incoming annotation is missing from the recorded one.
    #[error("Missing expected variable {0}.")]
    VariableMissing(VarId),
    /// A non-stack, ap-dependent variable would have to be merged while ap tracking is disabled.
    #[error("Ap tracking is disabled while trying to merge {0}.")]
    ApTrackingDisabled(VarId),
}
134
/// Annotation that represents the state at each program statement.
#[derive(Clone, Debug)]
pub struct StatementAnnotations {
    /// The references (variables) available at the statement.
    pub refs: StatementRefs,
    /// The function id that the statement belongs to.
    pub function_id: FunctionId,
    /// Indicates whether convergence is allowed in the given statement, i.e. whether another flow
    /// may still reach the statement once it is annotated (checked by
    /// `ProgramAnnotations::set_or_assert`).
    pub convergence_allowed: bool,
    /// The environment at the statement: ap tracking, stack size, frame state and gas wallet.
    pub environment: Environment,
}
145
/// Annotations of the program statements.
/// See StatementAnnotations.
pub struct ProgramAnnotations {
    /// Optional per statement annotation, indexed by statement index. A slot is `None` until the
    /// statement is annotated.
    per_statement_annotations: Vec<Option<StatementAnnotations>>,
    /// The indices of the statements that are the targets of backwards jumps. Their annotations
    /// are cloned rather than consumed when read by `get_annotations_after_take_args`.
    backwards_jump_indices: UnorderedHashSet<StatementIdx>,
}
154impl ProgramAnnotations {
155    fn new(n_statements: usize, backwards_jump_indices: UnorderedHashSet<StatementIdx>) -> Self {
156        ProgramAnnotations {
157            per_statement_annotations: iter::repeat_with(|| None).take(n_statements).collect(),
158            backwards_jump_indices,
159        }
160    }
161
162    /// Creates a ProgramAnnotations object based on 'n_statements', a given functions list
163    /// and metadata for the program.
164    pub fn create(
165        n_statements: usize,
166        backwards_jump_indices: UnorderedHashSet<StatementIdx>,
167        functions: &[Function],
168        metadata: &Metadata,
169        gas_usage_check: bool,
170        type_sizes: &TypeSizeMap,
171    ) -> Result<Self, AnnotationError> {
172        let mut annotations = ProgramAnnotations::new(n_statements, backwards_jump_indices);
173        for func in functions {
174            annotations.set_or_assert(
175                func.entry_point,
176                StatementAnnotations {
177                    refs: build_function_parameters_refs(func, type_sizes).map_err(|error| {
178                        AnnotationError::ReferencesError { statement_idx: func.entry_point, error }
179                    })?,
180                    function_id: func.id.clone(),
181                    convergence_allowed: false,
182                    environment: Environment::new(if gas_usage_check {
183                        GasWallet::Value(metadata.gas_info.function_costs[&func.id].clone())
184                    } else {
185                        GasWallet::Disabled
186                    }),
187                },
188            )?
189        }
190
191        Ok(annotations)
192    }
193
194    /// Sets the annotations at 'statement_idx' to 'annotations'
195    /// If the annotations for this statement were set previously asserts that the previous
196    /// assignment is consistent with the new assignment and verifies that convergence_allowed
197    /// is true.
198    pub fn set_or_assert(
199        &mut self,
200        statement_idx: StatementIdx,
201        annotations: StatementAnnotations,
202    ) -> Result<(), AnnotationError> {
203        let idx = statement_idx.0;
204        match self.per_statement_annotations.get(idx).ok_or(AnnotationError::InvalidStatementIdx)? {
205            None => self.per_statement_annotations[idx] = Some(annotations),
206            Some(expected_annotations) => {
207                if expected_annotations.function_id != annotations.function_id {
208                    return Err(AnnotationError::InconsistentFunctionId { statement_idx });
209                }
210                validate_environment_equality(
211                    &expected_annotations.environment,
212                    &annotations.environment,
213                )
214                .map_err(|error| AnnotationError::InconsistentEnvironments {
215                    statement_idx,
216                    error,
217                })?;
218                self.test_references_consistency(&annotations, expected_annotations).map_err(
219                    |error| AnnotationError::InconsistentReferencesAnnotation {
220                        statement_idx,
221                        error,
222                    },
223                )?;
224
225                // Note that we ignore annotations here.
226                // a flow cannot converge with a branch target.
227                if !expected_annotations.convergence_allowed {
228                    return Err(AnnotationError::InvalidConvergence { statement_idx });
229                }
230            }
231        };
232        Ok(())
233    }
234
235    /// Checks whether or not `actual` and `expected` references are consistent.
236    /// Returns an error representing the inconsistency.
237    fn test_references_consistency(
238        &self,
239        actual: &StatementAnnotations,
240        expected: &StatementAnnotations,
241    ) -> Result<(), InconsistentReferenceError> {
242        // Check if there is a mismatch at the number of variables.
243        if actual.refs.len() != expected.refs.len() {
244            return Err(InconsistentReferenceError::VariableCountMismatch);
245        }
246        let ap_tracking_enabled =
247            matches!(actual.environment.ap_tracking, ApTracking::Enabled { .. });
248        for (var_id, actual_ref) in actual.refs.iter() {
249            // Check if the variable exists in just one of the branches.
250            let Some(expected_ref) = expected.refs.get(var_id) else {
251                return Err(InconsistentReferenceError::VariableMissing(var_id.clone()));
252            };
253            // Check if the variable doesn't match on type, expression or stack information.
254            if actual_ref.ty != expected_ref.ty {
255                return Err(InconsistentReferenceError::TypeMismatch {
256                    var: var_id.clone(),
257                    expected: expected_ref.ty.clone(),
258                    actual: actual_ref.ty.clone(),
259                });
260            }
261            if actual_ref.expression != expected_ref.expression {
262                return Err(InconsistentReferenceError::ExpressionMismatch {
263                    var: var_id.clone(),
264                    expected: expected_ref.expression.clone(),
265                    actual: actual_ref.expression.clone(),
266                });
267            }
268            if actual_ref.stack_idx != expected_ref.stack_idx {
269                return Err(InconsistentReferenceError::StackIndexMismatch {
270                    var: var_id.clone(),
271                    expected: expected_ref.stack_idx,
272                    actual: actual_ref.stack_idx,
273                });
274            }
275            test_var_consistency(var_id, actual_ref, expected_ref, ap_tracking_enabled)?;
276        }
277        Ok(())
278    }
279
280    /// Returns the result of applying take_args to the StatementAnnotations at statement_idx.
281    /// Can be called only once per item, the item is removed from the annotations, and can no
282    /// longer be used for merges.
283    pub fn get_annotations_after_take_args<'a>(
284        &mut self,
285        statement_idx: StatementIdx,
286        ref_ids: impl Iterator<Item = &'a VarId>,
287    ) -> Result<(StatementAnnotations, Vec<ReferenceValue>), AnnotationError> {
288        let existing = self.per_statement_annotations[statement_idx.0]
289            .as_mut()
290            .ok_or(AnnotationError::MissingAnnotationsForStatement(statement_idx))?;
291        let mut updated = if self.backwards_jump_indices.contains(&statement_idx) {
292            existing.clone()
293        } else {
294            std::mem::replace(
295                existing,
296                StatementAnnotations {
297                    refs: Default::default(),
298                    function_id: existing.function_id.clone(),
299                    // Merging with this data is no longer allowed.
300                    convergence_allowed: false,
301                    environment: existing.environment.clone(),
302                },
303            )
304        };
305        let refs = std::mem::take(&mut updated.refs);
306        let (statement_refs, taken_refs) = take_args(refs, ref_ids).map_err(|error| {
307            AnnotationError::MissingReferenceError { statement_idx, var_id: error.var_id() }
308        })?;
309        updated.refs = statement_refs;
310        Ok((updated, taken_refs))
311    }
312
313    /// Propagates the annotations from `statement_idx` to 'destination_statement_idx'.
314    ///
315    /// `annotations` is the result of calling get_annotations_after_take_args at
316    /// `source_statement_idx` and `branch_changes` are the reference changes at each branch.
317    ///  if `must_set` is true, asserts that destination_statement_idx wasn't annotated before.
318    pub fn propagate_annotations(
319        &mut self,
320        source_statement_idx: StatementIdx,
321        destination_statement_idx: StatementIdx,
322        mut annotations: StatementAnnotations,
323        branch_info: &BranchInfo,
324        branch_changes: BranchChanges,
325        must_set: bool,
326    ) -> Result<(), AnnotationError> {
327        if must_set && self.per_statement_annotations[destination_statement_idx.0].is_some() {
328            return Err(AnnotationError::AnnotationAlreadySet {
329                source_statement_idx,
330                destination_statement_idx,
331            });
332        }
333
334        for (var_id, ref_value) in annotations.refs.iter_mut() {
335            if branch_changes.clear_old_stack {
336                ref_value.stack_idx = None;
337            }
338            ref_value.expression =
339                std::mem::replace(&mut ref_value.expression, ReferenceExpression::zero_sized())
340                    .apply_ap_change(branch_changes.ap_change)
341                    .map_err(|error| AnnotationError::ApChangeError {
342                        var_id: var_id.clone(),
343                        source_statement_idx,
344                        destination_statement_idx,
345                        introduction_point: ref_value.introduction_point.clone(),
346                        error,
347                    })?;
348        }
349        let mut refs = put_results(
350            annotations.refs,
351            zip_eq(
352                &branch_info.results,
353                branch_changes.refs.into_iter().map(|value| ReferenceValue {
354                    expression: value.expression,
355                    ty: value.ty,
356                    stack_idx: value.stack_idx,
357                    introduction_point: match value.introduction_point {
358                        OutputReferenceValueIntroductionPoint::New(output_idx) => {
359                            IntroductionPoint {
360                                source_statement_idx: Some(source_statement_idx),
361                                destination_statement_idx,
362                                output_idx,
363                            }
364                        }
365                        OutputReferenceValueIntroductionPoint::Existing(introduction_point) => {
366                            introduction_point
367                        }
368                    },
369                }),
370            ),
371        )
372        .map_err(|error| AnnotationError::OverrideReferenceError {
373            source_statement_idx,
374            destination_statement_idx,
375            var_id: error.var_id(),
376        })?;
377
378        // Since some variables on the stack may have been consumed by the libfunc, we need to
379        // find the new stack size. This is done by searching from the bottom of the stack until we
380        // find a missing variable.
381        let available_stack_indices: UnorderedHashSet<_> =
382            refs.values().flat_map(|r| r.stack_idx).collect();
383        let new_stack_size_opt = (0..branch_changes.new_stack_size)
384            .find(|i| !available_stack_indices.contains(&(branch_changes.new_stack_size - 1 - i)));
385        let stack_size = if let Some(new_stack_size) = new_stack_size_opt {
386            // The number of stack elements which were removed.
387            let stack_removal = branch_changes.new_stack_size - new_stack_size;
388            for (_, r) in refs.iter_mut() {
389                // Subtract the number of stack elements removed. If the result is negative,
390                // `stack_idx` is set to `None` and the variable is removed from the stack.
391                r.stack_idx =
392                    r.stack_idx.and_then(|stack_idx| stack_idx.checked_sub(stack_removal));
393            }
394            new_stack_size
395        } else {
396            branch_changes.new_stack_size
397        };
398
399        let ap_tracking = match branch_changes.ap_tracking_change {
400            ApTrackingChange::Disable => ApTracking::Disabled,
401            ApTrackingChange::Enable => {
402                if !matches!(annotations.environment.ap_tracking, ApTracking::Disabled) {
403                    return Err(AnnotationError::ApTrackingAlreadyEnabled {
404                        statement_idx: source_statement_idx,
405                    });
406                }
407                ApTracking::Enabled {
408                    ap_change: 0,
409                    base: ApTrackingBase::Statement(destination_statement_idx),
410                }
411            }
412            ApTrackingChange::None => {
413                update_ap_tracking(annotations.environment.ap_tracking, branch_changes.ap_change)
414                    .map_err(|error| AnnotationError::ApTrackingError {
415                        source_statement_idx,
416                        destination_statement_idx,
417                        error,
418                    })?
419            }
420        };
421
422        self.set_or_assert(
423            destination_statement_idx,
424            StatementAnnotations {
425                refs,
426                function_id: annotations.function_id,
427                convergence_allowed: !must_set,
428                environment: Environment {
429                    ap_tracking,
430                    stack_size,
431                    frame_state: annotations.environment.frame_state,
432                    gas_wallet: annotations
433                        .environment
434                        .gas_wallet
435                        .update(branch_changes.gas_change)
436                        .map_err(|error| AnnotationError::GasWalletError {
437                            source_statement_idx,
438                            destination_statement_idx,
439                            error,
440                        })?,
441                },
442            },
443        )
444    }
445
446    /// Validates the ap change and return types in a return statement.
447    pub fn validate_return_properties(
448        &self,
449        statement_idx: StatementIdx,
450        annotations: &StatementAnnotations,
451        functions: &[Function],
452        metadata: &Metadata,
453        return_refs: &[ReferenceValue],
454    ) -> Result<(), AnnotationError> {
455        // TODO(ilya): Don't use linear search.
456        let func = &functions.iter().find(|func| func.id == annotations.function_id).unwrap();
457
458        let expected_ap_tracking = match metadata.ap_change_info.function_ap_change.get(&func.id) {
459            Some(x) => ApTracking::Enabled { ap_change: *x, base: ApTrackingBase::FunctionStart },
460            None => ApTracking::Disabled,
461        };
462        if annotations.environment.ap_tracking != expected_ap_tracking {
463            return Err(AnnotationError::InvalidFunctionApChange {
464                statement_idx,
465                expected: expected_ap_tracking,
466                actual: annotations.environment.ap_tracking,
467            });
468        }
469
470        // Checks that the list of return reference contains has the expected types.
471        check_types_match(return_refs, &func.signature.ret_types)
472            .map_err(|error| AnnotationError::ReferencesError { statement_idx, error })?;
473        Ok(())
474    }
475
476    /// Validates the final annotation in a return statement.
477    pub fn validate_final_annotations(
478        &self,
479        statement_idx: StatementIdx,
480        annotations: &StatementAnnotations,
481        functions: &[Function],
482        metadata: &Metadata,
483        return_refs: &[ReferenceValue],
484    ) -> Result<(), AnnotationError> {
485        self.validate_return_properties(
486            statement_idx,
487            annotations,
488            functions,
489            metadata,
490            return_refs,
491        )?;
492        validate_final_environment(&annotations.environment)
493            .map_err(|error| AnnotationError::InconsistentEnvironments { statement_idx, error })
494    }
495}
496
497/// Checks whether or not the references `actual` and `expected` are consistent and can be merged
498/// in a way that will be re-compilable.
499/// Returns an error representing the inconsistency.
500fn test_var_consistency(
501    var_id: &VarId,
502    actual: &ReferenceValue,
503    expected: &ReferenceValue,
504    ap_tracking_enabled: bool,
505) -> Result<(), InconsistentReferenceError> {
506    // If the variable is on the stack, it can always be merged.
507    if actual.stack_idx.is_some() {
508        return Ok(());
509    }
510    // If the variable is not ap-dependent it can always be merged.
511    // Note: This makes the assumption that empty variables are always mergeable.
512    if actual.expression.can_apply_unknown() {
513        return Ok(());
514    }
515    // Ap tracking must be enabled when merging non-stack ap-dependent variables.
516    if !ap_tracking_enabled {
517        return Err(InconsistentReferenceError::ApTrackingDisabled(var_id.clone()));
518    }
519    // Merged variables must have the same introduction point.
520    if actual.introduction_point == expected.introduction_point {
521        Ok(())
522    } else {
523        Err(InconsistentReferenceError::IntroductionPointMismatch {
524            var: var_id.clone(),
525            expected: expected.introduction_point.clone(),
526            actual: actual.introduction_point.clone(),
527        })
528    }
529}